#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Consolidated Anomaly Detector with Batched TTA

This module provides a modern, consolidated implementation of anomaly detection
with DeCo-Diff using efficient batched Test-Time Augmentation (TTA).

Key Features:
- Combined Batch TTA: Processes B images × N shifts together (~6.6× speedup)
- Configurable shifts: pad_px, stride, directions parameters
- Clean architecture without legacy code
- Designed to replace DecodiffEvaluator and DecodiffEvaluateProcessor
"""
import torch
import numpy as np
import pandas as pd
from skimage.transform import resize
from skimage import measure
from sklearn.metrics import average_precision_score, auc, roc_auc_score, f1_score
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn.functional as F
from glob import glob
from pathlib import Path
import json
import matplotlib
matplotlib.use('Agg')  # Force non-GUI backend for thread safety (prevents TkAgg threading issues on Windows)
import matplotlib.pyplot as plt
from tqdm import tqdm
from PIL import Image as PILImage
import os
import cv2
import sys
import logging
import time
import copy
from typing import Optional, List, Tuple, Dict, Set, Any

# Configure UTF-8 encoding for output (Windows compatibility)
if sys.platform.startswith('win'):
    import io
    # Only wrap if not already a TextIOWrapper to avoid double-wrapping
    if not isinstance(sys.stdout, io.TextIOWrapper):
        sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8')
    if not isinstance(sys.stderr, io.TextIOWrapper):
        sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8')

from dioodmi.models import UNET_models
from dioodmi.algorithms.diffusion import create_diffusion
from diffusers.models import AutoencoderKL
from dioodmi.data.MVTECDataLoader import MVTECDataset
from dioodmi.data.VISADataLoader import VISADataset
from scipy.ndimage import gaussian_filter
from dioodmi.utils.plot_results import PlotResults as PR
from dioodmi.utils.utils import setup_signal_handlers, setup_logging
from dioodmi.data.temporal_dataset_factory import (
    create_temporal_dataset_manager,
    create_temporal_dataloader
)
from dioodmi.tasks.anomaly_detection.full_image_tiled_evaluator import (
    process_full_image_evaluation_loop,
    stack_anomaly_maps,
    setup_resume_mode
)

# ============================================================================
# Image Processing Utility Functions (Pure Functions)
# ============================================================================

def path_to_safe_filename(image_path: str) -> str:
    """Convert a file path to a safe filename for saving."""
    safe_name = os.path.basename(image_path)
    # Remove extension and replace problematic characters
    safe_name = os.path.splitext(safe_name)[0]
    safe_name = safe_name.replace(' ', '_').replace('/', '_').replace('\\', '_')
    return safe_name


def draw_patch_rectangles_on_image(base_img: np.ndarray, predicted_defective_set: set,
                                   ground_truth_defective: set, overlapping: set,
                                   patch_size: int = 256, grid_thickness: int = 1) -> np.ndarray:
    """Draw patch rectangles (TP/FP/FN) on top of an image."""
    img_copy = base_img.copy()
    h, w = img_copy.shape[:2]
    # Color definitions for different patch types (RGB order)
    colors = {
        'TP': (0, 255, 0),    # Green for True Positives (overlapping)
        'FP': (255, 0, 0),    # Red for False Positives (predicted only)
        'FN': (0, 0, 255),    # Blue for False Negatives (ground truth only)
    }

    def draw_rectangle(img, row, col, color, thickness):
        """Draw a single patch rectangle on the image."""
        x1 = col * patch_size
        y1 = row * patch_size
        x2 = min(x1 + patch_size, w)
        y2 = min(y1 + patch_size, h)
        cv2.rectangle(img, (x1, y1), (x2 - 1, y2 - 1), color, thickness)

    # Draw True Positives (overlapping patches) - Green
    for row, col in overlapping:
        draw_rectangle(img_copy, row, col, colors['TP'], grid_thickness)
    # Draw False Positives (predicted but not in ground truth) - Red
    false_positives = predicted_defective_set - ground_truth_defective
    for row, col in false_positives:
        draw_rectangle(img_copy, row, col, colors['FP'], grid_thickness)
    # Draw False Negatives (ground truth but not predicted) - Blue
    false_negatives = ground_truth_defective - predicted_defective_set
    for row, col in false_negatives:
        draw_rectangle(img_copy, row, col, colors['FN'], grid_thickness)
    return img_copy
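
# Illustrative sketch (not used by the pipeline): how the TP/FP/FN set algebra
# above maps onto drawn rectangles. The grid coordinates and image size here
# are made up for demonstration.
def _example_draw_patch_rectangles():
    base = np.zeros((512, 512, 3), dtype=np.uint8)
    predicted = {(0, 0), (0, 1)}       # patches flagged by the detector
    ground_truth = {(0, 1), (1, 1)}    # patches with real defects
    overlapping = predicted & ground_truth  # -> {(0, 1)} drawn green (TP)
    # (0, 0) is predicted-only -> red (FP); (1, 1) is truth-only -> blue (FN)
    return draw_patch_rectangles_on_image(base, predicted, ground_truth, overlapping)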

def create_anomaly_visualization(anomaly_map: np.ndarray, is_binary: bool = False,
                                 threshold: float = 5.0, add_grid: bool = True,
                                 patch_size: int = 256) -> np.ndarray:
    """Create an anomaly map visualization.

    Note: add_grid and patch_size are accepted for API compatibility but are
    not currently used by this function.
    """
    if is_binary:
        # Binary visualization: threshold and convert to a black/white mask.
        # Auto-detect the anomaly map range and normalize the threshold accordingly.
        amap_max = np.max(anomaly_map)
        if amap_max <= 1.0:
            # Anomaly map is in 0-1 range, normalize threshold from 0-255 to 0-1
            normalized_threshold = threshold / 255.0
        else:
            # Anomaly map is in 0-255 range, use threshold as-is
            normalized_threshold = threshold
        binary_map = (anomaly_map > normalized_threshold).astype(np.uint8)
        # Create RGB image: black for 0 (normal), white for 1 (anomaly)
        vis_img = np.zeros((binary_map.shape[0], binary_map.shape[1], 3), dtype=np.uint8)  # Start with black
        vis_img[binary_map == 1] = [255, 255, 255]  # Set anomaly pixels to white
        return vis_img
    else:
        # Continuous visualization: normalize and apply a colormap
        normalized = np.clip(anomaly_map, 0, 1)
        # Convert to 0-255 range and apply the jet colormap
        normalized_255 = (normalized * 255).astype(np.uint8)
        colored = cv2.applyColorMap(normalized_255, cv2.COLORMAP_JET)
        return cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)
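
# Illustrative sketch (assumed shapes/values): the binary path thresholds the
# map into a black/white mask, while the continuous path renders a jet heatmap.
def _example_create_anomaly_visualization():
    amap = np.random.rand(256, 256)  # anomaly map already in [0, 1]
    # threshold=5.0 on a [0, 1] map is auto-normalized to 5/255 ≈ 0.0196
    binary_vis = create_anomaly_visualization(amap, is_binary=True, threshold=5.0)
    heatmap_vis = create_anomaly_visualization(amap, is_binary=False)
    return binary_vis, heatmap_vis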

def create_anomaly_overlay(original_img: np.ndarray, anomaly_map: np.ndarray,
                           alpha: float = 0.8, is_binary: bool = False,
                           threshold: float = 5.0) -> np.ndarray:
    """Create an anomaly overlay on the original image."""
    # Create the anomaly visualization
    anomaly_vis = create_anomaly_visualization(anomaly_map, is_binary, threshold, add_grid=False)
    # Ensure both images have the same dimensions
    h, w = original_img.shape[:2]
    if anomaly_vis.shape[:2] != (h, w):
        # skimage.transform.resize is imported at module level
        anomaly_vis = resize(anomaly_vis, (h, w), preserve_range=True, anti_aliasing=True).astype(np.uint8)
    # Blend the visualization into the original image
    overlay = cv2.addWeighted(original_img, alpha, anomaly_vis, 1 - alpha, 0)
    return overlay


def determine_image_status(patch_results: list) -> str:
    """Determine the overall image status from patch results."""
    status_counts = {'TP': 0, 'FN': 0, 'FP': 0, 'TN': 0}
    for result in patch_results:
        status = result.get('status', 'TN')
        if status in status_counts:
            status_counts[status] += 1
    # Image is positive if it has any defective patches (TP or FN)
    has_ground_truth_defects = status_counts['TP'] + status_counts['FN'] > 0
    has_predicted_defects = status_counts['TP'] + status_counts['FP'] > 0
    if has_ground_truth_defects and has_predicted_defects:
        return 'TP'  # Correctly identified as defective
    elif has_ground_truth_defects and not has_predicted_defects:
        return 'FN'  # Missed defective image
    elif not has_ground_truth_defects and has_predicted_defects:
        return 'FP'  # False alarm
    else:
        return 'TN'  # Correctly identified as normal
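
# Illustrative sketch: image-level status is derived purely from patch-level
# statuses, so a single TP patch makes the whole image a TP even if other
# patches are FN. The dicts below are made-up patch results.
def _example_determine_image_status():
    patches = [{'status': 'TP'}, {'status': 'FN'}, {'status': 'TN'}]
    assert determine_image_status(patches) == 'TP'             # defect present and detected
    assert determine_image_status([{'status': 'FP'}]) == 'FP'  # false alarm only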

def _generate_shifts(
    pad_px: int,
    stride: int,
    directions: tuple
) -> List[Tuple[int, int]]:
    """Generate shift coordinates based on parameters.

    Args:
        pad_px: Shift range (±pad_px pixels)
        stride: Shift increment (1=every pixel, 2=skip pixels)
        directions: ("h", "v", "diag") | ("all",) | ("rhombus",)

    Returns:
        Sorted list of (dx, dy) shift coordinates

    Number of shifts generated:
        - ("h", "v", "diag") with pad_px=4, stride=1: 33 shifts
        - ("h", "v", "diag") with pad_px=2, stride=1: 17 shifts
        - ("h", "v") with pad_px=4, stride=1: 17 shifts
        - ("all",) with pad_px=4, stride=1: 81 shifts
        - ("all",) with pad_px=4, stride=2: 25 shifts
    """
    shifts = set()
    if "all" in directions or directions == "all":
        for dx in range(-pad_px, pad_px + 1, stride):
            for dy in range(-pad_px, pad_px + 1, stride):
                shifts.add((dx, dy))
    elif "rhombus" in directions or directions == "rhombus":
        for dx in range(-pad_px, pad_px + 1, stride):
            for dy in range(-pad_px, pad_px + 1, stride):
                if abs(dx) + abs(dy) <= (pad_px + 1):
                    shifts.add((dx, dy))
    else:
        if "h" in directions:
            for dx in range(-pad_px, pad_px + 1, stride):
                shifts.add((dx, 0))
        if "v" in directions:
            for dy in range(-pad_px, pad_px + 1, stride):
                shifts.add((0, dy))
        if "diag" in directions:
            for d in range(-pad_px, pad_px + 1, stride):
                shifts.add((d, d))
                shifts.add((d, -d))
    # Always include the center (no-shift) position
    shifts.add((0, 0))
    return sorted(shifts)
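
# Illustrative sketch: the shift counts documented above can be verified
# directly. For ("h", "v", "diag") the four axes share only the origin,
# giving 8 * pad_px / stride + 1 unique shifts when stride divides pad_px.
def _example_generate_shifts():
    assert len(_generate_shifts(4, 1, ("h", "v", "diag"))) == 33
    assert len(_generate_shifts(2, 1, ("h", "v", "diag"))) == 17
    assert len(_generate_shifts(4, 1, ("all",))) == 81
    assert len(_generate_shifts(4, 2, ("all",))) == 25
    assert (0, 0) in _generate_shifts(4, 2, ("h",))  # center always included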

def _get_shift_type_and_amount(dx: int, dy: int) -> Tuple[str, int]:
    """Determine shift type and amount from (dx, dy) coordinates.

    Args:
        dx: Horizontal shift in pixels
        dy: Vertical shift in pixels

    Returns:
        Tuple[str, int]: (shift_type, shift_amount)
            shift_type: 'none', 'h' (horizontal), 'v' (vertical), 'diag' (diagonal)
            shift_amount: Pixel offset amount

    Examples:
        >>> _get_shift_type_and_amount(0, 0)
        ('none', 0)
        >>> _get_shift_type_and_amount(3, 0)
        ('h', 3)
        >>> _get_shift_type_and_amount(0, -2)
        ('v', -2)
        >>> _get_shift_type_and_amount(2, 2)
        ('diag', 2)
        >>> _get_shift_type_and_amount(-3, 3)
        ('diag', 3)
    """
    if dx == 0 and dy == 0:
        return 'none', 0
    elif dy == 0:  # Pure horizontal
        return 'h', dx
    elif dx == 0:  # Pure vertical
        return 'v', dy
    else:  # Diagonal (dx != 0 and dy != 0)
        # Use the max absolute value as the amount, preserve sign
        max_abs = max(abs(dx), abs(dy))
        # Determine sign: use the sign of the component with the larger
        # absolute value (ties fall back to the sign of dy)
        if abs(dx) > abs(dy):
            sign = 1 if dx > 0 else -1
        else:
            sign = 1 if dy > 0 else -1
        return 'diag', sign * max_abs

def create_confusion_matrix_from_patch_results(
    patch_results_by_image: dict,
    output_dir: str,
    patch_size: int = 256
) -> None:
    """
    Create a confusion matrix visualization from patch results.

    Args:
        patch_results_by_image: Dict mapping image_path -> list of patch results.
            Each patch result has: grid_row, grid_col, anomaly_pixels, pred_score, status
        output_dir: Directory to save the confusion matrix plot and JSON
        patch_size: Size of patches for grid coordinate calculation
    """
    # json and matplotlib (with the 'Agg' backend) are imported at module level
    # Collect all patch results
    all_patch_results = []
    for image_path, patches in patch_results_by_image.items():
        all_patch_results.extend(patches)
    if not all_patch_results:
        print("No patch results provided for confusion matrix generation")
        return
    # Initialize confusion matrix counters
    all_TP = all_FP = all_FN = all_TN = 0
    print(f"Creating confusion matrix for {len(patch_results_by_image)} images...")
    # Count status values from patch results
    for patch in all_patch_results:
        status = patch.get('status', 'TN')
        if status == "TP":
            all_TP += 1
        elif status == "FP":
            all_FP += 1
        elif status == "FN":
            all_FN += 1
        elif status == "TN":
            all_TN += 1
    total = all_TP + all_FP + all_FN + all_TN
    accuracy = (all_TP + all_TN) / total if total > 0 else 0
    # Calculate additional metrics (the local name avoids shadowing sklearn's f1_score)
    precision = all_TP / (all_TP + all_FP) if (all_TP + all_FP) > 0 else 0
    recall = all_TP / (all_TP + all_FN) if (all_TP + all_FN) > 0 else 0
    f1_value = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    print("Confusion Matrix (patch-level):")
    print(f"TP: {all_TP}, FP: {all_FP}, FN: {all_FN}, TN: {all_TN}")
    print(f"Accuracy: {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f}")
    print(f"F1-Score: {f1_value:.4f}")
    # Create the confusion matrix visualization (rows = true label, cols = predicted label)
    cm = np.array([[all_TP, all_FN], [all_FP, all_TN]])
    plt.figure(figsize=(8, 6))
    plt.imshow(cm, interpolation='nearest', cmap='Blues')
    plt.title('Confusion Matrix (Patch-level)', fontsize=16, fontweight='bold')
    plt.colorbar()
    # Add text annotations
    thresh = cm.max() / 2.
    for i in range(2):
        for j in range(2):
            plt.text(j, i, format(cm[i, j], 'd'),
                     ha="center", va="center",
                     color="white" if cm[i, j] > thresh else "black",
                     fontsize=14, fontweight='bold')
    # Set labels
    tick_marks = np.arange(2)
    plt.xticks(tick_marks, ['Defective', 'Normal'], fontsize=12)
    plt.yticks(tick_marks, ['Defective', 'Normal'], fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    # Add metrics text
    metrics_text = (f'Accuracy: {accuracy:.4f}\nPrecision: {precision:.4f}\n'
                    f'Recall: {recall:.4f}\nF1-Score: {f1_value:.4f}')
    plt.figtext(0.02, 0.02, metrics_text, fontsize=10, bbox=dict(boxstyle="round,pad=0.3", facecolor="lightgray"))
    plt.tight_layout()
    # Save the confusion matrix plot
    os.makedirs(output_dir, exist_ok=True)
    cm_plot_path = os.path.join(output_dir, "confusion_matrix.png")
    plt.savefig(cm_plot_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"Confusion matrix plot saved to: {cm_plot_path}")
    # Save detailed results to file
    result = {
        "TP": all_TP,
        "FP": all_FP,
        "FN": all_FN,
        "TN": all_TN,
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1_score": f1_value,
        "total_patches": total
    }
    with open(os.path.join(output_dir, "confusion_matrix.json"), "w") as f:
        json.dump(result, f, indent=2)
    print(f"Confusion matrix results saved to: {os.path.join(output_dir, 'confusion_matrix.json')}")

# ============================================================================
# Checkpoint Manager
# ============================================================================

class CheckpointManager:
    """Manages checkpoint and organized image saving functionality."""

    def __init__(self, results_dir: str, enable_checkpointing: bool = False):
        self.results_dir = Path(results_dir)
        self.enable_checkpointing = enable_checkpointing
        self.marked_images_dir = self.results_dir / "marked_images"
        self.evaluation_results_dir = self.results_dir / "evaluation_results"
        # Create base directories
        self.marked_images_dir.mkdir(parents=True, exist_ok=True)
        self.evaluation_results_dir.mkdir(parents=True, exist_ok=True)
        # Setup the organized folder structure
        self._setup_directory_structure()

    def _setup_directory_structure(self):
        """Create only the base directory structure.

        Subfolders are created on demand when files are actually saved, so
        this is intentionally a no-op beyond the base directories above.
        """
        pass

    def get_status_folder(self, status: str) -> Path:
        """Get the path for a status-specific folder."""
        return self.marked_images_dir / status

    def get_image_level_folder(self) -> Path:
        """Get the path for the image_level folder."""
        return self.marked_images_dir / 'image_level'
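
# Illustrative sketch: the folder layout produced by CheckpointManager.
# The results_dir value is hypothetical.
def _example_checkpoint_manager():
    mgr = CheckpointManager(results_dir="/tmp/anomaly_results", enable_checkpointing=True)
    tp_dir = mgr.get_status_folder("TP")    # /tmp/anomaly_results/marked_images/TP
    img_dir = mgr.get_image_level_folder()  # /tmp/anomaly_results/marked_images/image_level
    return tp_dir, img_dir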

# ============================================================================
# Anomaly Detector (Consolidated Class)
# ============================================================================

class AnomalyDetector:
    """
    Modern Anomaly Detector with Batched TTA

    This class consolidates functionality from DecodiffEvaluator and
    DecodiffEvaluateProcessor into a single, efficient implementation.

    Key Features:
    - Batched TTA for ~6.6× speedup (Option B: Combined Batch)
    - Configurable shift parameters (pad_px, stride, directions)
    - Clean architecture without legacy code
    """

    def __init__(self, args):
        # Setup DDP-compatible attributes (for signal handler compatibility).
        # AnomalyDetector doesn't use DDP, but we set these for consistency.
        self.ddp = False
        self.local_rank = 0
        self.is_main_process = True  # Always True for single-process evaluation
        # Initialize the shutdown flag before anything can check it;
        # setup_signal_handlers() below is expected to flip it on SIGINT/SIGTERM.
        self._shutdown_requested = False
        # Setup logging early (before other operations that might log)
        self.setup_logging(args)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            self.log_message("⚠ GPU not found. Using CPU instead.", level='warning')
        # Set AMP dtype for mixed precision
        self.amp_dtype = (
            torch.bfloat16
            if (self.device == "cuda" and torch.cuda.is_bf16_supported())
            else torch.float16
        )
        # Store difference augmentation parameters from args.
        # clip_max is used to clip and normalize (divide by clip_max) to keep the output range [0, 1.0].
        self.image_diff_clip_max = getattr(args, 'image_diff_clip_max', 1.0)
        self.latent_diff_clip_max = getattr(args, 'latent_diff_clip_max', 1.0)
        # Store anomaly detection parameters from args
        self.anomaly_threshold = getattr(args, 'anomaly_threshold', 20)
        self.anomaly_min_area = getattr(args, 'anomaly_min_area', 10)
        # Validate clip_max values at initialization
        if self.image_diff_clip_max == 0:
            raise ValueError("image_diff_clip_max cannot be 0 (division by zero)")
        if self.latent_diff_clip_max == 0:
            raise ValueError("latent_diff_clip_max cannot be 0 (division by zero)")
        torch.set_grad_enabled(False)
        # Store filename strategy configuration
        self.regression_test_mode = getattr(args, 'regression_test_mode', False)
        self.object_class = getattr(args, 'object_category', None)
        self.filename_strategy_name = getattr(args, 'filename_strategy', None)
        self.setup_model(args)
        # File saving control (can be disabled for performance testing)
        self._disable_file_saving = False  # Set to True to disable all file saving for performance testing
        self.save_plot_path = None if args.save_plot_path is None else Path(args.save_plot_path)
        if self.save_plot_path and not self._disable_file_saving:
            self.save_plot_path.mkdir(parents=True, exist_ok=True)
        # Initialize the checkpoint manager for image saving if enabled.
        # Note: CheckpointManager is used for saving evaluation images, so we enable it
        # when save_evaluation_images is True, regardless of the checkpoint_enabled setting.
        self.checkpoint_manager = None
        if args.save_evaluation_images and self.save_plot_path and not self._disable_file_saving:
            self.checkpoint_manager = CheckpointManager(
                results_dir=str(self.save_plot_path),
                enable_checkpointing=True  # Always enable for evaluation image saving
            )
        # Setup the save queue and worker thread if save_plot_path is set
        self.save_queue = None
        self.save_worker_thread = None
        if self.save_plot_path and not self._disable_file_saving:
            self._start_save_worker()
        # Setup signal handlers for graceful shutdown (using utility function)
        setup_signal_handlers(
            self,
            ddp=self.ddp,
            is_main_process=self.is_main_process,
            log_message_func=self.log_message  # Use logging for AnomalyDetector
        )

    def setup_logging(self, args):
        """Setup logging with proper Unicode support (using utility function)."""
        # Determine the log file path - use save_plot_path if available, otherwise the current directory
        if hasattr(args, 'save_plot_path') and args.save_plot_path:
            log_dir = Path(args.save_plot_path).parent
            log_file_path = log_dir / "log.txt"
        else:
            # Fallback to the current directory
            log_file_path = Path("log.txt")
        # Check for a debug flag in args
        debug_enabled = getattr(args, 'debug', False) or getattr(args, 'debug_enabled', False)
        # Use the utility function for logging setup
        self.logger = setup_logging(
            log_file_path=log_file_path,
            logger_name=__name__,
            is_main_process=self.is_main_process,  # Always True for AnomalyDetector
            debug_enabled=debug_enabled
        )
        # Store the debug flag for conditional timing logs
        self.debug_enabled = debug_enabled
        # Log initialization message
        self.logger.info(f"Logging initialized - logs saved to: {log_file_path}")

    def log_message(self, message, level='info'):
        """
        Log a message to both console and file.

        Args:
            message: Message to log
            level: Log level ('info', 'warning', 'error', 'debug')
        """
        if level == 'info':
            self.logger.info(message)
        elif level == 'warning':
            self.logger.warning(message)
        elif level == 'error':
            self.logger.error(message)
        elif level == 'debug':
            self.logger.debug(message)
        else:
            self.logger.info(message)  # Default to info

    def setup_model(self, args):
        """Setup the model and VAE for evaluation."""
        # Load VAE (os is imported at module level)
        models_path = "./models"
        models_config = os.path.join(models_path, "config.json")
        current_dir = os.getcwd()
        self.log_message(f"\n{'='*80}")
        self.log_message("[DEBUG] VAE Model Loading")
        self.log_message(f"{'='*80}")
        self.log_message(f"Current working directory: {current_dir}")
        self.log_message(f"Checking for local models at: {os.path.abspath(models_path)}")
        self.log_message(f"Looking for config at: {os.path.abspath(models_config)}")
        self.log_message(f"Config exists: {os.path.exists(models_config)}")
        # Check the current directory first
        if os.path.exists(models_config):
            self.log_message(f"✓ Found local VAE model, loading from: {os.path.abspath(models_path)}")
            vae_model = models_path
            self.vae = AutoencoderKL.from_pretrained(vae_model, local_files_only=True).to(self.device)
        else:
            # Fallback: check the parent directory
            parent_models_path = "../models"
            parent_models_config = os.path.join(parent_models_path, "config.json")
            self.log_message(f"Checking parent directory for models at: {os.path.abspath(parent_models_path)}")
            self.log_message(f"Looking for config at: {os.path.abspath(parent_models_config)}")
            self.log_message(f"Config exists: {os.path.exists(parent_models_config)}")
            if os.path.exists(parent_models_config):
                self.log_message(f"✓ Found local VAE model in parent directory, loading from: {os.path.abspath(parent_models_path)}")
                vae_model = parent_models_path
                self.vae = AutoencoderKL.from_pretrained(vae_model, local_files_only=True).to(self.device)
            else:
                self.log_message("✗ Local VAE not found in current or parent directory, downloading from HuggingFace")
                vae_model = f"stabilityai/sd-vae-ft-{args.vae_type}"
                self.log_message(f"Downloading: {vae_model}")
                self.vae = AutoencoderKL.from_pretrained(vae_model).to(self.device)
        self.log_message(f"{'='*80}\n")
        self.vae.eval()
        # Find the checkpoint (sorted(glob(...))[-1] raises IndexError when nothing matches)
        try:
            if args.model_path != '':
                ckpt = args.model_path
            else:
                path = f"./DeCo-Diff_{args.dataset}_{args.object_category}_{args.model_size}_{args.crop_size}"
                try:
                    ckpt = sorted(glob(f'{path}/last.pt'))[-1]
                except IndexError:
                    ckpt = sorted(glob(f'{path}/*/last.pt'))[-1]
        except (IndexError, AttributeError):
            raise Exception("Please provide the trained model's path using --model_path")
        # Setup the model
        latent_size = int(args.crop_size) // 8
        self.model = UNET_models[args.model_size](latent_size=latent_size, ncls=args.num_classes)
        # Load the checkpoint with automatic device mapping (CUDA if available, else CPU)
        map_location = self.device if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(ckpt, weights_only=False, map_location=map_location)
        # Check if EMA weights are available and use them if present
        if 'ema_state_dict' in checkpoint:
            # The EMA state dict contains 'ema_model_state_dict' with the actual weights
            ema_state_dict = checkpoint['ema_state_dict']
            if 'ema_model_state_dict' in ema_state_dict:
                state_dict = ema_state_dict['ema_model_state_dict']
                self.log_message('Using EMA model weights for evaluation')
            else:
                # Fallback: assume ema_state_dict is the model state dict directly
                state_dict = ema_state_dict
                self.log_message('Using EMA model weights for evaluation (direct format)')
        else:
            # Fall back to the regular model weights
            state_dict = checkpoint['model']
            self.log_message('Using regular model weights (no EMA found in checkpoint)')
        load_result = self.model.load_state_dict(state_dict)
        self.log_message(str(load_result))
        self.model.eval()
        self.model.to(self.device)
        self.log_message('✓ Model loaded')

    def _start_save_worker(self):
        """Start a background worker thread for concurrent image saving."""
        from queue import Queue
        import threading
        self.save_queue = Queue()

        def worker():
            """Background thread that processes image save requests from the queue."""
            while True:
                try:
                    item = self.save_queue.get()
                    if item is None:  # Sentinel value to stop the worker
                        break
                    kind = item.get('kind', 'array')
                    if kind == 'array':
                        filepath = item['filepath']
                        image_array = item['image_array']
                        Path(filepath).parent.mkdir(parents=True, exist_ok=True)
                        PILImage.fromarray(image_array, mode='L').save(filepath)
                    elif kind == 'plot':
                        # Handle plot saving in the background thread.
                        # Convert numpy arrays back to torch tensors
                        # (plot_shift_analysis expects tensors for orig_imgs).
                        orig_imgs = item['orig_imgs']
                        if isinstance(orig_imgs, np.ndarray):
                            orig_imgs = torch.from_numpy(orig_imgs).float()
                        # averaged_anomaly_map can stay as numpy (it's used directly in imshow, not passed to rescale_range)
                        averaged_anomaly_map = item['averaged_anomaly_map']
                        # Convert any numpy arrays in the shift data back to tensors (for crop_original/crop_reconstruction)
                        individual_shifts = item['individual_shifts']
                        for shift_data in individual_shifts:
                            if 'crop_original' in shift_data and shift_data['crop_original'] is not None:
                                if isinstance(shift_data['crop_original'], np.ndarray):
                                    shift_data['crop_original'] = torch.from_numpy(shift_data['crop_original']).float()
                            if 'crop_reconstruction' in shift_data and shift_data['crop_reconstruction'] is not None:
                                if isinstance(shift_data['crop_reconstruction'], np.ndarray):
                                    shift_data['crop_reconstruction'] = torch.from_numpy(shift_data['crop_reconstruction']).float()
                        # PR (PlotResults) is imported at module level
                        PR.plot_shift_analysis(
                            individual_shifts=individual_shifts,
                            averaged_anomaly_map=averaged_anomaly_map,
                            batch_idx=item['batch_idx'],
                            orig_imgs=orig_imgs,
                            save_path=item['save_path'],
                            max_shifts_to_show=item.get('max_shifts_to_show', 16),
                            fuse_method=item.get('fuse_method', 'mean')
                        )
                    else:
                        print(f"Unknown save job kind: {kind}")
                except Exception as e:
                    print(f"Error in save worker: {e}")
                    import traceback
                    traceback.print_exc()
                finally:
                    try:
                        self.save_queue.task_done()
                    except Exception as e:
                        print(f"Error calling task_done: {e}")

        self.save_worker_thread = threading.Thread(target=worker, daemon=True)
        self.save_worker_thread.start()
        self.log_message("Started background image save worker thread")

    def wait_for_saves(self, timeout=300):
        """
        Wait for all queued save operations to complete.

        Args:
            timeout: Maximum time to wait in seconds (default: 5 minutes)
        """
        if not self.save_queue or not self.is_save_worker_alive():
            return
        start_time = time.time()
        queue_size = self.get_save_queue_size()
        if queue_size > 0:
            print(f"Waiting for {queue_size} save operations to complete...")
        # Wait for the queue to be empty
        while not self.save_queue.empty():
            if self._shutdown_requested:
                print("Shutdown requested, stopping wait for saves...")
                break
            if time.time() - start_time > timeout:
                print(f"Warning: Save operations timed out after {timeout} seconds")
                print(f"Remaining queue size: {self.get_save_queue_size()}")
                break
            time.sleep(0.1)  # Small delay to prevent busy waiting
        # Wait for all tasks to complete
        try:
            self.save_queue.join()
            if queue_size > 0:
                print("All save operations completed successfully")
        except Exception as e:
            print(f"Error waiting for saves to complete: {e}")

    def is_save_worker_alive(self):
        """Check if the save worker thread is still alive."""
        return self.save_worker_thread and self.save_worker_thread.is_alive()

    def get_save_queue_size(self):
        """Get the current size of the save queue."""
        return self.save_queue.qsize() if self.save_queue else 0

    def shutdown_save_worker(self):
        """Stop the save worker thread with an adaptive timeout based on queue size."""
        if self.save_queue and self.save_worker_thread:
            queue_size = self.get_save_queue_size()
            # Adaptive timeout: 5s base + 0.1s per queued item (min 10s).
            # Example: 200 items -> 25s timeout, 500 items -> 55s timeout.
            timeout = max(10, queue_size * 0.1 + 5)
            if queue_size > 0:
                self.log_message(f"Stopping save worker ({queue_size} items remaining, timeout={timeout:.0f}s)...")
            self.save_queue.put(None)  # Send the sentinel to stop the worker
            self.save_worker_thread.join(timeout=timeout)
            if self.save_worker_thread.is_alive():
                remaining = self.get_save_queue_size()
                self.log_message(f"⚠️ Warning: Save worker timeout after {timeout:.0f}s ({remaining} items not saved)", level='warning')
            else:
                if queue_size > 0:
                    self.log_message(f"✓ Save worker stopped ({queue_size} items saved)")
                else:
                    self.log_message("Save worker thread stopped")

    def compute_pro(self, masks, amaps, num_th=200):
        """Compute the area under the per-region overlap (PRO) curve over the 0 to 0.3 FPR range."""
        assert isinstance(amaps, np.ndarray), "type(amaps) must be ndarray"
        assert isinstance(masks, np.ndarray), "type(masks) must be ndarray"
        assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)"
        assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)"
        assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same"
        unique_values = set(masks.flatten())
        assert unique_values.issubset({0, 1, 0.0, 1.0}), f"set(masks.flatten()) must be subset of {{0, 1}}, got {unique_values}"
        assert isinstance(num_th, int), "type(num_th) must be int"
        df = pd.DataFrame([], columns=["pro", "fpr", "threshold"])
        binary_amaps = np.zeros_like(amaps, dtype=bool)
        min_th = amaps.min()
        max_th = amaps.max()
        # Handle the case where all anomaly map values are the same
        if min_th == max_th:
            # Return 0.0 since AUPRO integration between 0 and 0.3 for a degenerate (0,0)-(1,1) curve is 0
            return 0.0
        # Use linspace for better numerical stability instead of arange.
        # endpoint=False excludes max_th, so the highest threshold still yields some positives.
        thresholds = np.linspace(min_th, max_th, num_th, endpoint=False)
        for th in thresholds:
            binary_amaps[amaps <= th] = 0
            binary_amaps[amaps > th] = 1
            pros = []
            for binary_amap, mask in zip(binary_amaps, masks):
                for region in measure.regionprops(measure.label(mask)):
                    axes0_ids = region.coords[:, 0]
                    axes1_ids = region.coords[:, 1]
                    tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()
                    pros.append(tp_pixels / region.area)
            inverse_masks = 1 - masks
            fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
            fpr = fp_pixels / inverse_masks.sum()
            df.loc[len(df)] = [np.mean(pros) if len(pros) > 0 else 0, fpr, th]
        # Keep only the 0-0.3 FPR range, then rescale FPR to [0, 1] before integrating
        df = df[df["fpr"] < 0.3]
        if df.empty:
            return 0.0
        df["fpr"] = df["fpr"] / df["fpr"].max()
        # Calculate the area under the PRO-FPR curve
        pro_auc = auc(df["fpr"], df["pro"])
        return pro_auc
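
    # Illustrative usage sketch (comments only; `detector` is a hypothetical
    # AnomalyDetector instance). One 64x64 image with a 10x10 defect region:
    #
    #   masks = np.zeros((1, 64, 64), dtype=np.float32)
    #   masks[0, 20:30, 20:30] = 1.0
    #   amaps = np.random.rand(1, 64, 64).astype(np.float32)
    #   amaps[0, 20:30, 20:30] += 1.0               # make the defect region score higher
    #   aupro = detector.compute_pro(masks, amaps)  # scalar in [0, 1]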

    def compute_pro_parallel(self, masks, amaps, num_th=200, n_jobs=-1):
        """Phase 5: Parallel PRO curve computation (~1.3× faster).

        Computes the per-region overlap (PRO) curve using parallel processing across thresholds.

        Args:
            masks: Ground truth masks [N, H, W]
            amaps: Anomaly maps [N, H, W]
            num_th: Number of thresholds (default: 200)
            n_jobs: Number of parallel jobs (-1 = all cores, default: -1)

        Returns:
            float: Area under the PRO-FPR curve (0 to 0.3 FPR range)
        """
        from joblib import Parallel, delayed
        assert isinstance(amaps, np.ndarray), "type(amaps) must be ndarray"
        assert isinstance(masks, np.ndarray), "type(masks) must be ndarray"
        assert amaps.ndim == 3, "amaps.ndim must be 3 (num_test_data, h, w)"
        assert masks.ndim == 3, "masks.ndim must be 3 (num_test_data, h, w)"
        assert amaps.shape == masks.shape, "amaps.shape and masks.shape must be same"
        unique_values = set(masks.flatten())
        assert unique_values.issubset({0, 1, 0.0, 1.0}), f"set(masks.flatten()) must be subset of {{0, 1}}, got {unique_values}"
        assert isinstance(num_th, int), "type(num_th) must be int"
        min_th = amaps.min()
        max_th = amaps.max()
        # Handle the case where all anomaly map values are the same
        if min_th == max_th:
            return 0.0
        thresholds = np.linspace(min_th, max_th, num_th, endpoint=False)

        def _process_threshold(th):
            """Process a single threshold to compute PRO and FPR."""
            # Binarize the anomaly maps
            binary_amaps = (amaps > th).astype(bool)
            # Compute PRO for each ground-truth region
            pros = []
            for binary_amap, mask in zip(binary_amaps, masks):
                for region in measure.regionprops(measure.label(mask)):
                    axes0_ids = region.coords[:, 0]
                    axes1_ids = region.coords[:, 1]
                    tp_pixels = binary_amap[axes0_ids, axes1_ids].sum()
                    pros.append(tp_pixels / region.area)
            # Compute FPR
            inverse_masks = 1 - masks
            fp_pixels = np.logical_and(inverse_masks, binary_amaps).sum()
            fpr = fp_pixels / inverse_masks.sum()
            return {
                'pro': np.mean(pros) if len(pros) > 0 else 0,
                'fpr': fpr,
                'threshold': th
            }

        # Parallel processing across thresholds
        results = Parallel(n_jobs=n_jobs, prefer='threads')(
            delayed(_process_threshold)(th) for th in thresholds
        )
        # Convert to a DataFrame
        df = pd.DataFrame(results)
        # Keep only the 0-0.3 FPR range, then rescale FPR to [0, 1] before integrating
        df = df[df["fpr"] < 0.3]
        if df.empty:
            return 0.0
        df["fpr"] = df["fpr"] / df["fpr"].max()
        # Calculate the area under the PRO-FPR curve
        pro_auc = auc(df["fpr"], df["pro"])
        return pro_auc

    def calculate_metrics(self, masks, amaps, parallel_pro=True):
        """Calculate AUROC, AUPRO, F1, and AP metrics at the pixel and sample level.

        Args:
            masks: Ground truth masks
            amaps: Anomaly maps
            parallel_pro: Use parallel PRO computation (default: True)
        """
        # Flatten for sklearn pixel-level metrics
        masks_flat = masks.flatten()
        amaps_flat = amaps.flatten()
        # Pixel-level metrics.
        # Use sklearn roc_auc_score instead of anomalib (API compatibility).
        if len(set(masks_flat)) > 1:  # Need both classes
            auroc_score = roc_auc_score(masks_flat, amaps_flat)
        else:
            auroc_score = 0.0
        # F1 score (pixel-level) - using the median threshold
        threshold = np.median(amaps_flat)
        pred_labels = (amaps_flat > threshold).astype(int)
        if len(set(masks_flat)) > 1:
            f1_score_val = f1_score(masks_flat, pred_labels)
        else:
            f1_score_val = 0.0
        ap_score = average_precision_score(masks_flat, amaps_flat) if len(set(masks_flat)) > 1 else 0.0
        # Sample-level metrics (image-level): an image is anomalous if its mask has
        # any positive pixel, and its score is the max of its anomaly map
        labels = []
        scores = []
        for i in range(len(masks)):
            labels.append(1 if masks[i].max() > 0 else 0)
            scores.append(amaps[i].max())
        if len(set(labels)) > 1:
            auroc_sp = roc_auc_score(labels, scores)
        else:
            auroc_sp = 0.0
        ap_sp = average_precision_score(labels, scores) if len(set(labels)) > 1 else 0.0
        # F1 score (sample-level) - using the median threshold
        threshold = np.median(scores)
        pred_labels = [1 if s > threshold else 0 for s in scores]
        if len(set(labels)) > 1:
            f1_sp = f1_score(labels, pred_labels)
        else:
            f1_sp = 0.0
        # AUPRO - use the parallel version if enabled
        if parallel_pro:
            aupro_score = self.compute_pro_parallel(masks, amaps)
        else:
            aupro_score = self.compute_pro(masks, amaps)
        return auroc_score, aupro_score, f1_score_val, ap_score, auroc_sp, ap_sp, f1_sp
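
    # Illustrative usage sketch (comments only; shapes assumed [N, H, W]):
    #
    #   (auroc_px, aupro, f1_px, ap_px,
    #    auroc_img, ap_img, f1_img) = detector.calculate_metrics(masks, amaps)
    #
    # The first four values are pixel-level, the last three image-level.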

    def smooth_mask(self, mask, sigma=3):
        """Apply Gaussian smoothing to a mask."""
        return gaussian_filter(mask, sigma=sigma)

    def calculate_anomaly_maps_batch_gpu(self, x0, encoded, image_samples, latent_samples, crop_size=256,
                                         image_diff_clip_max=1.0, latent_diff_clip_max=1.0):
        """Phase 3: GPU-accelerated anomaly map calculation (1.5-2× faster).

        Keeps all tensors on GPU and uses torch.nn.functional.interpolate for resizing.
        Returns torch tensors instead of numpy arrays for downstream GPU processing.

        Args:
            x0: Original decoded images [B, C, H, W] (GPU tensor)
            encoded: Encoded latents [B, C, H', W'] (GPU tensor)
            image_samples: Reconstructed images [B, C, H, W] (GPU tensor)
            latent_samples: Reconstructed latents [B, C, H', W'] (GPU tensor)
            crop_size: Target size for latent resizing
            image_diff_clip_max: Clipping threshold for the image difference
            latent_diff_clip_max: Clipping threshold for the latent difference

        Returns:
            dict: Anomaly maps as GPU tensors [B, H, W]
        """
        # Validate clip_max values
        if image_diff_clip_max == 0:
            raise ValueError("image_diff_clip_max cannot be 0")
        if latent_diff_clip_max == 0:
            raise ValueError("latent_diff_clip_max cannot be 0")
        # Image difference (kept on GPU)
        image_difference = torch.abs(image_samples - x0).float().mean(dim=1)  # [B, H, W]
        image_difference = torch.clamp(image_difference, 0.0, image_diff_clip_max) / image_diff_clip_max
        # Latent difference with GPU resize
        latent_difference = torch.abs(latent_samples - encoded).float().mean(dim=1)  # [B, H', W']
        latent_difference = torch.clamp(latent_difference, 0.0, latent_diff_clip_max) / latent_diff_clip_max
        # Resize the latent map to match the image size using GPU interpolation
        latent_difference = F.interpolate(
            latent_difference.unsqueeze(1),  # [B, 1, H', W']
            size=(crop_size, crop_size),
            mode='bilinear',
            align_corners=False
        ).squeeze(1)  # [B, H, W]
        # Compute the anomaly maps (all GPU ops)
        pred_geometric = torch.sqrt(image_difference * latent_difference)
        pred_arithmetic = 0.5 * image_difference + 0.5 * latent_difference
        # Return GPU tensors (no .cpu().numpy() conversion!)
        return {
            'anomaly_geometric': pred_geometric,      # [B, H, W] torch tensor on GPU
            'anomaly_arithmetic': pred_arithmetic,    # [B, H, W] torch tensor on GPU
            'latent_discrepancy': latent_difference,  # [B, H, W] torch tensor on GPU
            'image_discrepancy': image_difference,    # [B, H, W] torch tensor on GPU
        }
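
    # Illustrative sketch (comments only; toy tensors with hypothetical shapes):
    # the geometric map sqrt(img_diff * latent_diff) suppresses pixels where only
    # one of the two cues fires, while the arithmetic map averages them.
    #
    #   x0      = torch.rand(2, 3, 256, 256)
    #   recon   = torch.rand(2, 3, 256, 256)
    #   enc     = torch.rand(2, 4, 32, 32)
    #   lat_rec = torch.rand(2, 4, 32, 32)
    #   maps = detector.calculate_anomaly_maps_batch_gpu(x0, enc, recon, lat_rec, crop_size=256)
    #   maps['anomaly_geometric'].shape   # torch.Size([2, 256, 256])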

    def calculate_anomaly_maps_batch(self, x0, encoded, image_samples, latent_samples, crop_size=256,
                                     image_diff_clip_max=1.0, latent_diff_clip_max=1.0):
        """Legacy CPU-based anomaly map calculation.

        Converts to numpy arrays immediately. Kept for backward compatibility.

        Args:
            x0: Original decoded images
            encoded: Encoded latents
            image_samples: Reconstructed images from diffusion
            latent_samples: Reconstructed latents from diffusion
            crop_size: Crop size for resizing
            image_diff_clip_max: Maximum value to clip the image difference to (scale = 1.0 / clip_max)
            latent_diff_clip_max: Maximum value to clip the latent difference to (scale = 1.0 / clip_max)
        """
        B = x0.shape[0]
        # Validate clip_max values
        if image_diff_clip_max == 0:
            raise ValueError("image_diff_clip_max cannot be 0 (division by zero)")
        if latent_diff_clip_max == 0:
            raise ValueError("latent_diff_clip_max cannot be 0 (division by zero)")
        # Absolute differences (for anomaly detection)
        image_difference_raw = torch.abs(image_samples - x0).float().mean(dim=1)
        image_difference = image_difference_raw.detach().cpu().numpy()
        # Clip and normalize by dividing by clip_max (keeps the output range [0, 1.0]).
        # Example: clip_max=0.4 -> np.clip(image_difference, 0.0, 0.4) / 0.4
        image_difference = np.clip(image_difference, 0.0, image_diff_clip_max) / image_diff_clip_max
        image_differences = image_difference
        latent_difference = torch.abs(latent_samples - encoded).float().mean(dim=1)
        latent_difference = latent_difference.detach().cpu().numpy()
        # Clip and normalize by dividing by clip_max (keeps the output range [0, 1.0]).
        # Example: clip_max=0.2 -> np.clip(latent_difference, 0.0, 0.2) / 0.2
        latent_difference = np.clip(latent_difference, 0.0, latent_diff_clip_max) / latent_diff_clip_max
        latent_differences = []
        for i in range(B):
            resized = resize(latent_difference[i], (crop_size, crop_size))
            latent_differences.append(resized)
        latent_differences = np.stack(latent_differences, axis=0)
        # Signed differences (for grayscale visualization: raw values clipped to -1~1).
        # Computed on GPU, then moved to CPU only when needed.
        image_signed_diff = (image_samples - x0).float().mean(dim=1)
        # Clip to [-1, 1] on GPU (images are in [-1, 1], so the difference can be in [-2, 2])
        image_signed_diff = torch.clamp(image_signed_diff, -1.0, 1.0).detach().cpu().numpy()
        latent_signed_diff = (latent_samples - encoded).float().mean(dim=1)
        # Clip on GPU, then resize on CPU
        latent_signed_diff = torch.clamp(latent_signed_diff, -1.0, 1.0).detach().cpu().numpy()
        latent_signed_diffs = []
        for i in range(B):
            resized = resize(latent_signed_diff[i], (crop_size, crop_size))
            latent_signed_diffs.append(resized)
        latent_signed_diffs = np.stack(latent_signed_diffs, axis=0)
        pred_geometric = np.sqrt(image_differences * latent_differences)
        pred_arithmetic = 0.5 * image_differences + 0.5 * latent_differences
        return {
            'anomaly_geometric': pred_geometric,
            'anomaly_arithmetic': pred_arithmetic,
            'latent_discrepancy': latent_differences,
            'image_discrepancy': image_differences,
            'image_signed_diff': image_signed_diff,    # For grayscale: -1~1 range
            'latent_signed_diff': latent_signed_diffs  # For grayscale: -1~1 range
        }
| def _anomaly_shift_avg_combined_batch( | |
| self, | |
| x_batch: torch.Tensor, | |
| *, | |
| model_kwargs: dict, | |
| diffusion, | |
| reverse_steps: int, | |
| pad_px: int = 4, | |
| img_enlarge_px: int = 4, | |
| stride: int = 1, | |
| directions: tuple = ("h", "v", "diag"), | |
| shift_method: str = 'interpolate', | |
| fuse_method: str = 'mean', | |
| tta_batch_size: int = 8, | |
| eta: float = 0.0, | |
| crop_size: int = 256, | |
| ) -> Tuple[Dict[str, np.ndarray], List[List[Dict]]]: | |
| """Option B: Combined Batch TTA for maximum speedup. | |
| Processes B DataLoader images × N shifts in combined batches of size tta_batch_size. | |
| This batches BOTH the image dimension AND the shift dimension together. | |
| Args: | |
| x_batch: Input images [B, 3, H, W] from DataLoader | |
| model_kwargs: Model parameters (context, mask) | |
| diffusion: Diffusion model | |
| reverse_steps: Number of DDIM steps | |
| pad_px: Shift range (±pad_px pixels) | |
| img_enlarge_px: Image enlargement for interpolation | |
| stride: Shift increment | |
| directions: Shift directions to use | |
| shift_method: 'interpolate' or 'mirror' | |
| fuse_method: How to aggregate shifts ('mean', 'median', 'lowest', 'pct25', 'pct75') | |
| tta_batch_size: Number of shifts to process together per batch | |
| eta: DDIM eta parameter (0=deterministic) | |
| crop_size: Final crop size | |
| Returns: | |
| Tuple of: | |
| - Dictionary with keys: | |
| - 'anomaly_geometric': [B, H, W] | |
| - 'anomaly_arithmetic': [B, H, W] | |
| - 'latent_discrepancy': [B, H, W] | |
| - 'image_discrepancy': [B, H, W] | |
| - List[List[Dict]]: individual_shifts[img_idx] = list of shift dicts for plotting | |
| """ | |
| device = self.device | |
| B, C, H, W = x_batch.shape | |
| img_size = H | |
| # Validate pad_px | |
| if pad_px > img_enlarge_px: | |
| pad_px = img_enlarge_px | |
| print(f"⚠ Warning: pad_px clamped to img_enlarge_px={img_enlarge_px}") | |
| # Generate all shifts | |
| shifts = _generate_shifts(pad_px, stride, directions) | |
| num_shifts = len(shifts) | |
| # Use print() for high-frequency progress messages to avoid file I/O overhead | |
| print(f"Processing {B} images × {num_shifts} shifts in batches of {tta_batch_size}") | |
| # Prepare upsampled/padded input | |
| if shift_method == 'interpolate': | |
| aug_size = img_size + 2 * img_enlarge_px | |
| offset = img_enlarge_px | |
| x_up = F.interpolate( | |
| x_batch, size=(aug_size, aug_size), | |
| mode='bilinear', align_corners=False | |
| ).to(device) | |
| elif shift_method == 'mirror': | |
| aug_size = img_size + 2 * pad_px | |
| offset = pad_px | |
| x_up = F.pad( | |
| x_batch, (pad_px, pad_px, pad_px, pad_px), | |
| mode='reflect' | |
| ).to(device) | |
| else: | |
| raise ValueError(f"Unknown shift_method: {shift_method}. Choose 'interpolate' or 'mirror'") | |
| # Storage for all images × all shifts | |
| keys = ['anomaly_geometric', 'anomaly_arithmetic', 'latent_discrepancy', 'image_discrepancy'] | |
| signed_keys = ['image_signed_diff', 'latent_signed_diff'] | |
| # Structure: all_maps[key][img_idx][shift_idx] = canvas | |
| all_maps = {k: [[None for _ in range(num_shifts)] for _ in range(B)] for k in keys} | |
| all_signed_maps = {k: [[None for _ in range(num_shifts)] for _ in range(B)] for k in signed_keys} | |
| # Storage for individual shifts (for plot_shift_analysis) | |
| # Structure: individual_shifts[img_idx] = list of shift dicts | |
| individual_shifts = [[] for _ in range(B)] | |
| # Process shifts in mini-batches | |
| num_batches = (num_shifts + tta_batch_size - 1) // tta_batch_size | |
| for batch_idx in range(num_batches): | |
| batch_start_time = time.time() | |
| # Check for shutdown request | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ TTA processing interrupted at batch {batch_idx}/{num_batches}. Shutting down...") | |
| break | |
| batch_start = batch_idx * tta_batch_size | |
| batch_end = min(batch_start + tta_batch_size, num_shifts) | |
| batch_shifts = shifts[batch_start:batch_end] | |
| print(f" Batch {batch_idx+1}/{num_batches}: processing shifts {batch_start}-{batch_end-1}") | |
| # Create combined batch: [B × len(batch_shifts), 3, H, W] | |
| prep_start = time.time() | |
| combined_crops = [] | |
| metadata = [] # Track (img_idx, shift_idx_in_full_list, shift, position) | |
| for img_idx in range(B): | |
| for shift_local_idx, (dx, dy) in enumerate(batch_shifts): | |
| shift_global_idx = batch_start + shift_local_idx | |
| y0 = offset + dy | |
| x0 = offset + dx | |
| y1 = y0 + img_size | |
| x1 = x0 + img_size | |
| # Skip invalid crops | |
| if not (0 <= y0 < y1 <= aug_size and 0 <= x0 < x1 <= aug_size): | |
| continue | |
| crop = x_up[img_idx:img_idx+1, :, y0:y1, x0:x1] | |
| combined_crops.append(crop) | |
| metadata.append({ | |
| 'img_idx': img_idx, | |
| 'shift_idx': shift_global_idx, | |
| 'shift': (dx, dy), | |
| 'position': (x0, y0, x1, y1) | |
| }) | |
| if not combined_crops: | |
| continue | |
| # Stack: [B*tta_batch_size, 3, H, W] | |
| combined_tensor = torch.cat(combined_crops, dim=0) | |
| prep_time = time.time() - prep_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Batch prep: {prep_time:.3f}s", level='debug') | |
| # Expand context tensor to match combined batch size | |
| context_start = time.time() | |
| # Each crop needs its corresponding category index | |
| combined_batch_size = combined_tensor.shape[0] | |
| if 'context' in model_kwargs and model_kwargs['context'] is not None: | |
| # Get original context shape [B, 1] | |
| original_context = model_kwargs['context'] | |
| # Validate original context | |
| if original_context is None: | |
| raise ValueError(f"original_context is None in batch {batch_idx}") | |
| if not isinstance(original_context, torch.Tensor): | |
| raise TypeError(f"original_context must be Tensor, got {type(original_context)}") | |
| # Build expanded context based on metadata to handle skipped crops correctly | |
| # metadata contains the img_idx for each crop in the combined batch | |
| expanded_context_list = [] | |
| for meta in metadata: | |
| img_idx = meta['img_idx'] | |
| # Validate img_idx is within bounds | |
| if img_idx >= original_context.shape[0]: | |
| raise IndexError(f"img_idx {img_idx} >= original_context.shape[0] {original_context.shape[0]}") | |
| # Get the context for this image | |
| img_context = original_context[img_idx:img_idx+1] # [1, 1] | |
| expanded_context_list.append(img_context) | |
| if expanded_context_list: | |
| expanded_context = torch.cat(expanded_context_list, dim=0) # [combined_batch_size, 1] | |
| else: | |
| # Fallback if no metadata (shouldn't happen, but be safe) | |
| num_shifts_in_batch = len(batch_shifts) | |
| expanded_context = original_context.repeat_interleave(num_shifts_in_batch, dim=0) | |
| if expanded_context.shape[0] != combined_batch_size: | |
| expanded_context = expanded_context[:combined_batch_size] | |
| # Validate expanded context | |
| if expanded_context is None: | |
| raise ValueError(f"expanded_context is None in batch {batch_idx}") | |
| if expanded_context.shape[0] != combined_batch_size: | |
| raise ValueError(f"expanded_context.shape[0] {expanded_context.shape[0]} != combined_batch_size {combined_batch_size}") | |
| model_kwargs_batch = { | |
| 'context': expanded_context, | |
| 'mask': model_kwargs.get('mask', None) | |
| } | |
| else: | |
| # If no context in model_kwargs, this is an error for class-conditional models | |
| raise ValueError(f"model_kwargs['context'] is missing or None in batch {batch_idx} - required for class-conditional model") | |
| context_time = time.time() - context_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Context expansion: {context_time:.3f}s", level='debug') | |
| # Process entire combined batch at once (FAST!) | |
| gpu_start = time.time() | |
| with torch.no_grad(): | |
| encode_start = time.time() | |
| encoded = self.vae.encode(combined_tensor).latent_dist.mean.mul_(0.18215) | |
| if torch.cuda.is_available(): | |
| torch.cuda.synchronize() | |
| encode_time = time.time() - encode_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] VAE encode: {encode_time:.3f}s", level='debug') | |
| diffusion_start = time.time() | |
| latent_samples = diffusion.ddim_deviation_sample_loop( | |
| self.model, encoded.shape, noise=encoded, | |
| clip_denoised=False, start_t=reverse_steps, | |
| model_kwargs=model_kwargs_batch, progress=False, | |
| device=device, eta=eta | |
| ) | |
| if torch.cuda.is_available(): | |
| torch.cuda.synchronize() | |
| diffusion_time = time.time() - diffusion_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Diffusion sampling: {diffusion_time:.3f}s", level='debug') | |
| decode_start = time.time() | |
| image_samples = self.vae.decode(latent_samples / 0.18215).sample | |
| x0_decoded = self.vae.decode(encoded / 0.18215).sample | |
| if torch.cuda.is_available(): | |
| torch.cuda.synchronize() | |
| decode_time = time.time() - decode_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] VAE decode: {decode_time:.3f}s", level='debug') | |
| gpu_time = time.time() - gpu_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Total GPU: {gpu_time:.3f}s", level='debug') | |
| # Calculate anomaly maps for entire batch | |
| anomaly_start = time.time() | |
| anomaly_batch = self.calculate_anomaly_maps_batch( | |
| x0_decoded, encoded, image_samples, latent_samples, | |
| crop_size=img_size, | |
| image_diff_clip_max=self.image_diff_clip_max, | |
| latent_diff_clip_max=self.latent_diff_clip_max | |
| ) | |
| anomaly_time = time.time() - anomaly_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Anomaly map calc: {anomaly_time:.3f}s", level='debug') | |
| # Distribute results back to each image | |
| distribute_start = time.time() | |
| for i, meta in enumerate(metadata): | |
| img_idx = meta['img_idx'] | |
| shift_idx = meta['shift_idx'] | |
| x0, y0, x1, y1 = meta['position'] | |
| # Place on canvas for this image/shift | |
| for key in keys: | |
| canvas = np.full((aug_size, aug_size), np.nan, dtype=np.float32) | |
| canvas[y0:y1, x0:x1] = anomaly_batch[key][i] | |
| all_maps[key][img_idx][shift_idx] = canvas | |
| # Place signed differences on canvas | |
| for key in signed_keys: | |
| if key in anomaly_batch: | |
| canvas = np.full((aug_size, aug_size), np.nan, dtype=np.float32) | |
| canvas[y0:y1, x0:x1] = anomaly_batch[key][i] | |
| all_signed_maps[key][img_idx][shift_idx] = canvas | |
| # Store individual shift data for plot_shift_analysis | |
| # Get the crop and reconstruction for this shift | |
| # Note: combined_crops[i] corresponds to the crop at index i in the combined batch | |
| if i < len(combined_crops): | |
| crop_original = combined_crops[i].detach().cpu() | |
| else: | |
| crop_original = None | |
| if i < image_samples.shape[0]: | |
| crop_reconstruction = image_samples[i:i+1].detach().cpu() | |
| else: | |
| crop_reconstruction = None | |
| individual_shifts[img_idx].append({ | |
| 'shift': meta['shift'], # Use shift from metadata, not undefined (dx, dy) | |
| 'position': (x0, y0, x1, y1), | |
| 'arithmetic_combined': anomaly_batch['anomaly_arithmetic'][i].copy(), | |
| 'crop_original': crop_original, | |
| 'crop_reconstruction': crop_reconstruction, | |
| 'shift_method': shift_method | |
| }) | |
| distribute_time = time.time() - distribute_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Result distribution: {distribute_time:.3f}s", level='debug') | |
| batch_total_time = time.time() - batch_start_time | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Batch {batch_idx+1} total: {batch_total_time:.3f}s", level='debug') | |
| # Average shifts for each image | |
| fuse_start = time.time() | |
| final_anomaly_maps = {} | |
| # Optimize: Batch resize operations on GPU if using interpolate | |
| if shift_method == 'interpolate': | |
| # Collect all averaged maps first, then batch resize on GPU | |
| all_averaged_maps = {} # key -> list of [H, W] arrays | |
| for key in keys: | |
| all_averaged_maps[key] = [] | |
| for img_idx in range(B): | |
| # Stack all shifts for this image (filter out None entries) | |
| valid_maps = [m for m in all_maps[key][img_idx] if m is not None] | |
| if not valid_maps: | |
| # No valid shifts (shouldn't happen) | |
| all_averaged_maps[key].append(np.zeros((aug_size, aug_size), dtype=np.float32)) | |
| continue | |
| stack = np.stack(valid_maps, axis=-1) # [H, W, num_valid_shifts] | |
| # Fuse across shifts using optimized partition-based method | |
| averaged = self._fuse_array(stack, fuse_method, axis=-1) | |
| all_averaged_maps[key].append(averaged) | |
| # Batch resize all maps on GPU at once (much faster) | |
| for key in keys: | |
| # Stack all averaged maps: [B, H, W] | |
| stacked_averaged = np.stack(all_averaged_maps[key], axis=0) | |
| # Convert to tensor and resize on GPU | |
| averaged_tensor = torch.from_numpy(stacked_averaged).float().unsqueeze(1).to(self.device) # [B, 1, H, W] | |
| resized_tensor = F.interpolate( | |
| averaged_tensor, size=(img_size, img_size), | |
| mode='bilinear', align_corners=False | |
| ) | |
| # Move back to CPU and remove channel dimension | |
| final_anomaly_maps[key] = resized_tensor.squeeze(1).cpu().numpy() # [B, H, W] | |
| else: | |
| # Mirror method: just slice, no resize needed (fast path) | |
| for key in keys: | |
| final_maps = [] | |
| for img_idx in range(B): | |
| # Stack all shifts for this image (filter out None entries) | |
| valid_maps = [m for m in all_maps[key][img_idx] if m is not None] | |
| if not valid_maps: | |
| # No valid shifts (shouldn't happen) | |
| final_maps.append(np.zeros((img_size, img_size), dtype=np.float32)) | |
| continue | |
| stack = np.stack(valid_maps, axis=-1) # [H, W, num_valid_shifts] | |
| # Fuse across shifts using optimized partition-based method | |
| averaged = self._fuse_array(stack, fuse_method, axis=-1) | |
| # Just slice the valid region (no resize needed for mirror) | |
| resized = averaged[pad_px:pad_px+img_size, pad_px:pad_px+img_size] | |
| final_maps.append(resized) | |
| final_anomaly_maps[key] = np.stack(final_maps, axis=0) | |
| fuse_time = time.time() - fuse_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Shift fusion: {fuse_time:.3f}s", level='debug') | |
| # Average signed differences for each image | |
| for key in signed_keys: | |
| final_maps = [] | |
| for img_idx in range(B): | |
| # Stack all shifts for this image (filter out None entries) | |
| valid_maps = [m for m in all_signed_maps[key][img_idx] if m is not None] | |
| if not valid_maps: | |
| # No valid shifts (shouldn't happen) | |
| final_maps.append(np.zeros((img_size, img_size), dtype=np.float32)) | |
| continue | |
| stack = np.stack(valid_maps, axis=-1) # [H, W, num_valid_shifts] | |
| # Fuse across shifts (use mean for signed differences) | |
| averaged = np.nanmean(stack, axis=-1) | |
| # Resize back to original size | |
| if shift_method == 'interpolate': | |
| averaged_tensor = torch.from_numpy(averaged).float().unsqueeze(0).unsqueeze(0) | |
| resized_tensor = F.interpolate( | |
| averaged_tensor, size=(img_size, img_size), | |
| mode='bilinear', align_corners=False | |
| ) | |
| resized = resized_tensor.squeeze().numpy() | |
| else: | |
| resized = averaged[pad_px:pad_px+img_size, pad_px:pad_px+img_size] | |
| final_maps.append(resized) | |
| final_anomaly_maps[key] = np.stack(final_maps, axis=0) | |
| return final_anomaly_maps, individual_shifts | |
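| # A minimal sketch of the NaN-canvas fusion idea used above (illustrative only; | |
| # the real fusion goes through self._fuse_array, which supports other fuse_method | |
| # values). Each shift writes its crop onto a NaN-filled canvas, so np.nanmean | |
| # averages only the pixels that shift actually covered: | |
| # | |
| #   import numpy as np | |
| #   canvases = [] | |
| #   for (cx0, cy0, cx1, cy1), crop in [((0, 0, 4, 4), np.ones((4, 4))), | |
| #                                      ((1, 1, 5, 5), np.full((4, 4), 3.0))]: | |
| #       canvas = np.full((5, 5), np.nan, dtype=np.float32) | |
| #       canvas[cy0:cy1, cx0:cx1] = crop | |
| #       canvases.append(canvas) | |
| #   fused = np.nanmean(np.stack(canvases, axis=-1), axis=-1) | |
| #   # overlap averages to 2.0; pixels covered by one shift keep its value; | |
| #   # pixels covered by no shift remain NaN | |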
| def _compute_tta_per_shift_fundamentals( | |
| self, | |
| x_batch: torch.Tensor, | |
| *, | |
| model_kwargs: dict, | |
| diffusion, | |
| reverse_steps: int, | |
| pad_px: int = 4, | |
| img_enlarge_px: int = 4, | |
| stride: int = 1, | |
| directions: tuple = ("h", "v", "diag"), | |
| shift_method: str = 'interpolate', | |
| tta_batch_size: int = 8, | |
| eta: float = 0.0, | |
| crop_size: int = 256, | |
| ) -> Dict[str, Any]: | |
| """Compute per-shift fundamentals for TTA (Test-Time Augmentation). | |
| Similar to _anomaly_shift_avg_combined_batch but stores fundamentals per-shift | |
| instead of computing and fusing anomaly maps. | |
| Args: | |
| x_batch: Input images [B, 3, H, W] from DataLoader | |
| model_kwargs: Model parameters (context, mask) | |
| diffusion: Diffusion model | |
| reverse_steps: Number of DDIM steps | |
| pad_px: Shift range (±pad_px pixels) | |
| img_enlarge_px: Image enlargement for interpolation | |
| stride: Shift increment | |
| directions: Shift directions to use | |
| shift_method: 'interpolate' or 'mirror' | |
| tta_batch_size: Number of shifts to process together per batch | |
| eta: DDIM eta parameter (0=deterministic) | |
| crop_size: Final crop size (accepted for interface symmetry; not used in this method) | |
| Returns: | |
| Dictionary with keys: | |
| - 'fundamentals': Dict[str, List[List[np.ndarray]]] | |
| Format: fundamentals[key][img_idx][shift_idx] = [1, C, H, W] | |
| Keys: 'x0', 'encoded', 'image_samples', 'latent_samples' | |
| - 'shifts': List[(dx, dy)] - shift coordinates | |
| - 'num_shifts': int - total number of shifts | |
| - 'directions': tuple - shift directions used | |
| - 'pad_px': int - shift range used | |
| - 'stride': int - stride used | |
| """ | |
| device = self.device | |
| B, C, H, W = x_batch.shape | |
| img_size = H | |
| # Validate pad_px | |
| if pad_px > img_enlarge_px: | |
| pad_px = img_enlarge_px | |
| print(f"⚠ Warning: pad_px clamped to img_enlarge_px={img_enlarge_px}") | |
| # Generate all shifts | |
| shifts = _generate_shifts(pad_px, stride, directions) | |
| num_shifts = len(shifts) | |
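| # Illustration (inferred from the shift_type round-trip in _save/_load_fundamentals_iter, | |
| # which maps 'h'→(a, 0), 'v'→(0, a), 'diag'→(a, a), 'none'→(0, 0); exact contents and | |
| # order are defined by _generate_shifts): with pad_px=2, stride=1, all directions, this | |
| # plausibly yields (0,0) plus horizontal (±1,0),(±2,0), vertical (0,±1),(0,±2), and | |
| # diagonal (±1,±1),(±2,±2). | |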
| self.log_message(f"Computing per-shift fundamentals for {B} images × {num_shifts} shifts") | |
| # Prepare upsampled/padded input | |
| if shift_method == 'interpolate': | |
| aug_size = img_size + 2 * img_enlarge_px | |
| offset = img_enlarge_px | |
| x_up = F.interpolate( | |
| x_batch, size=(aug_size, aug_size), | |
| mode='bilinear', align_corners=False | |
| ).to(device) | |
| elif shift_method == 'mirror': | |
| aug_size = img_size + 2 * pad_px | |
| offset = pad_px | |
| x_up = F.pad( | |
| x_batch, (pad_px, pad_px, pad_px, pad_px), | |
| mode='reflect' | |
| ).to(device) | |
| else: | |
| raise ValueError(f"Unknown shift_method: {shift_method}. Choose 'interpolate' or 'mirror'") | |
| # Storage for all images × all shifts | |
| # Structure: all_fundamentals[key][img_idx][shift_idx] = np.ndarray [1, C, H, W] | |
| all_fundamentals = { | |
| 'x0': [[None for _ in range(num_shifts)] for _ in range(B)], | |
| 'encoded': [[None for _ in range(num_shifts)] for _ in range(B)], | |
| 'image_samples': [[None for _ in range(num_shifts)] for _ in range(B)], | |
| 'latent_samples': [[None for _ in range(num_shifts)] for _ in range(B)], | |
| } | |
| # Process shifts in mini-batches | |
| num_batches = (num_shifts + tta_batch_size - 1) // tta_batch_size | |
| for batch_idx in range(num_batches): | |
| # Check for shutdown request | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ TTA fundamentals computation interrupted at batch {batch_idx}/{num_batches}. Shutting down...") | |
| break | |
| batch_start = batch_idx * tta_batch_size | |
| batch_end = min(batch_start + tta_batch_size, num_shifts) | |
| batch_shifts = shifts[batch_start:batch_end] | |
| print(f" Batch {batch_idx+1}/{num_batches}: processing shifts {batch_start}-{batch_end-1}") | |
| # Create combined batch: [B × len(batch_shifts), 3, H, W] | |
| combined_crops = [] | |
| metadata = [] # Track (img_idx, shift_idx_in_full_list, shift, position) | |
| for img_idx in range(B): | |
| for shift_local_idx, (dx, dy) in enumerate(batch_shifts): | |
| shift_global_idx = batch_start + shift_local_idx | |
| y0 = offset + dy | |
| x0 = offset + dx | |
| y1 = y0 + img_size | |
| x1 = x0 + img_size | |
| # Skip invalid crops | |
| if not (0 <= y0 < y1 <= aug_size and 0 <= x0 < x1 <= aug_size): | |
| continue | |
| crop = x_up[img_idx:img_idx+1, :, y0:y1, x0:x1] | |
| combined_crops.append(crop) | |
| metadata.append({ | |
| 'img_idx': img_idx, | |
| 'shift_idx': shift_global_idx, | |
| 'shift': (dx, dy), | |
| 'position': (x0, y0, x1, y1) | |
| }) | |
| if not combined_crops: | |
| continue | |
| # Stack: [B*tta_batch_size, 3, H, W] | |
| combined_tensor = torch.cat(combined_crops, dim=0) | |
| # Expand context tensor to match combined batch size | |
| combined_batch_size = combined_tensor.shape[0] | |
| if 'context' in model_kwargs and model_kwargs['context'] is not None: | |
| original_context = model_kwargs['context'] | |
| # Validate original context | |
| if original_context is None: | |
| raise ValueError(f"original_context is None in batch {batch_idx}") | |
| if not isinstance(original_context, torch.Tensor): | |
| raise TypeError(f"original_context must be Tensor, got {type(original_context)}") | |
| # Build expanded context based on metadata | |
| expanded_context_list = [] | |
| for meta in metadata: | |
| img_idx = meta['img_idx'] | |
| if img_idx >= original_context.shape[0]: | |
| raise IndexError(f"img_idx {img_idx} >= original_context.shape[0] {original_context.shape[0]}") | |
| img_context = original_context[img_idx:img_idx+1] # [1, 1] | |
| expanded_context_list.append(img_context) | |
| if expanded_context_list: | |
| expanded_context = torch.cat(expanded_context_list, dim=0) | |
| else: | |
| num_shifts_in_batch = len(batch_shifts) | |
| expanded_context = original_context.repeat_interleave(num_shifts_in_batch, dim=0) | |
| if expanded_context.shape[0] != combined_batch_size: | |
| expanded_context = expanded_context[:combined_batch_size] | |
| # Validate expanded context | |
| if expanded_context is None: | |
| raise ValueError(f"expanded_context is None in batch {batch_idx}") | |
| if expanded_context.shape[0] != combined_batch_size: | |
| raise ValueError(f"expanded_context.shape[0] {expanded_context.shape[0]} != combined_batch_size {combined_batch_size}") | |
| model_kwargs_batch = { | |
| 'context': expanded_context, | |
| 'mask': model_kwargs.get('mask', None) | |
| } | |
| else: | |
| raise ValueError(f"model_kwargs['context'] is missing or None in batch {batch_idx}") | |
| # Process entire combined batch at once | |
| with torch.no_grad(): | |
| encoded = self.vae.encode(combined_tensor).latent_dist.mean.mul_(0.18215) | |
| latent_samples = diffusion.ddim_deviation_sample_loop( | |
| self.model, encoded.shape, noise=encoded, | |
| clip_denoised=False, start_t=reverse_steps, | |
| model_kwargs=model_kwargs_batch, progress=False, | |
| device=device, eta=eta | |
| ) | |
| image_samples = self.vae.decode(latent_samples / 0.18215).sample | |
| x0_decoded = self.vae.decode(encoded / 0.18215).sample | |
| # Store fundamentals per shift (the key difference from the fused anomaly-map path) | |
| for i, meta in enumerate(metadata): | |
| img_idx = meta['img_idx'] | |
| shift_idx = meta['shift_idx'] | |
| # Store each fundamental as [1, C, H, W] (single image) | |
| all_fundamentals['x0'][img_idx][shift_idx] = x0_decoded[i:i+1].cpu().numpy() | |
| all_fundamentals['encoded'][img_idx][shift_idx] = encoded[i:i+1].cpu().numpy() | |
| all_fundamentals['image_samples'][img_idx][shift_idx] = image_samples[i:i+1].cpu().numpy() | |
| all_fundamentals['latent_samples'][img_idx][shift_idx] = latent_samples[i:i+1].cpu().numpy() | |
| # Return per-shift fundamentals with metadata | |
| return { | |
| 'fundamentals': all_fundamentals, | |
| 'shifts': shifts, | |
| 'num_shifts': num_shifts, | |
| 'directions': directions, | |
| 'pad_px': pad_px, | |
| 'stride': stride | |
| } | |
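| # Example of consuming the structure returned above (hypothetical driver code; | |
| # `result` is the dict this method returns): | |
| # | |
| #   result = self._compute_tta_per_shift_fundamentals( | |
| #       x, model_kwargs=mk, diffusion=diffusion, reverse_steps=10) | |
| #   for img_idx in range(len(result['fundamentals']['x0'])): | |
| #       for shift_idx, (dx, dy) in enumerate(result['shifts']): | |
| #           recon = result['fundamentals']['image_samples'][img_idx][shift_idx] | |
| #           if recon is None: | |
| #               continue  # this shift fell outside the padded canvas | |
| #           # recon is a numpy array of shape [1, C, H, W] | |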
| def _evaluate_full_image_tiled(self, args, enable_resume=False): | |
| """ | |
| Evaluate full images using tiled approach. | |
| This method handles large images by tiling, reconstructing, and stitching. | |
| """ | |
| # Initialize diffusion | |
| diffusion = create_diffusion(str(args.reverse_steps)) | |
| # Get categories (handle both single category and list) | |
| categories = getattr(args, 'categories', [args.object_category]) | |
| if isinstance(categories, str): | |
| categories = [categories] | |
| # Storage for all results across categories | |
| all_category_anomaly_maps = {} | |
| all_category_masks = {} | |
| all_category_labels = {} | |
| # Simple plot save function (can be extended) | |
| def queue_plot_save_batch_or_individual(results_list, img_paths, base_filepath, | |
| anomaly_threshold=None, min_area=None, | |
| use_batch_plot=True): | |
| # Placeholder: resolves defaults but intentionally performs no plotting yet | |
| # Use instance variables if not provided | |
| if anomaly_threshold is None: | |
| anomaly_threshold = self.anomaly_threshold | |
| if min_area is None: | |
| min_area = self.anomaly_min_area | |
| queue_plot_save_func = queue_plot_save_batch_or_individual if self.save_plot_path else None | |
| for category in categories: | |
| # Check for shutdown request | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ Full image evaluation interrupted at category {category}. Shutting down...") | |
| break | |
| self.log_message(f"\n{'='*80}") | |
| self.log_message(f"Evaluating category: {category}") | |
| self.log_message(f"{'='*80}\n") | |
| # Setup resume mode | |
| json_output_path = None | |
| existing_json_data = None | |
| processed_image_paths = set() | |
| json_output_path, existing_json_data, processed_image_paths, self.save_plot_path = \ | |
| setup_resume_mode( | |
| enable_resume=enable_resume, | |
| category=category, | |
| save_plot_path=self.save_plot_path, | |
| data_dir=args.data_dir, | |
| json_output_path=None | |
| ) | |
| # Create temporal dataset manager | |
| test_dataset_manager = create_temporal_dataset_manager( | |
| dataset_name=args.dataset, | |
| mode='test', | |
| object_class=category, | |
| rootdir=args.data_dir, | |
| image_size=getattr(args, 'dataset_image_size', args.image_size), | |
| crop_size=args.crop_size, | |
| anomaly_class=getattr(args, 'anomaly_class', None), | |
| csv_split_file=getattr(args, 'csv_split_file', None) | |
| ) | |
| # Create temporal dataloader | |
| test_loader = create_temporal_dataloader( | |
| dataset_manager=test_dataset_manager, | |
| epoch=0, | |
| batch_size=args.batch_size, | |
| shuffle=False, | |
| num_workers=args.num_workers, | |
| drop_last=False | |
| ) | |
| # Initialize storage for this category | |
| anomaly_maps_padded = { | |
| 'anomaly_geometric': [], | |
| 'anomaly_arithmetic': [], | |
| 'latent_discrepancy': [], | |
| 'image_discrepancy': [] | |
| } | |
| all_json_outputs = [] | |
| # Process full image evaluation loop (save_queue and save_thread are shared across categories) | |
| anomaly_maps_padded, all_json_outputs, existing_json_data, processed_image_paths = \ | |
| process_full_image_evaluation_loop( | |
| test_loader=test_loader, | |
| test_dataset_manager=test_dataset_manager, | |
| device=self.device, | |
| vae=self.vae, | |
| model=self.model, | |
| amp_dtype=self.amp_dtype, | |
| diffusion=diffusion, | |
| args=args, | |
| category=category, | |
| data_dir=args.data_dir, | |
| calculate_anomaly_maps_batch_func=lambda x0, encoded, image_samples, latent_samples, crop_size: \ | |
| self.calculate_anomaly_maps_batch( | |
| x0, encoded, image_samples, latent_samples, crop_size, | |
| image_diff_clip_max=self.image_diff_clip_max, | |
| latent_diff_clip_max=self.latent_diff_clip_max | |
| ), | |
| save_plot_path=self.save_plot_path, | |
| save_queue=self.save_queue, | |
| queue_plot_save_batch_or_individual_func=queue_plot_save_func, | |
| enable_resume=enable_resume, | |
| json_output_path=json_output_path, | |
| existing_json_data=existing_json_data, | |
| processed_image_paths=processed_image_paths, | |
| all_json_outputs=all_json_outputs, | |
| anomaly_maps_padded=anomaly_maps_padded | |
| ) | |
| # Stack anomaly maps | |
| anomaly_maps = stack_anomaly_maps(anomaly_maps_padded) | |
| # Store results for this category | |
| all_category_anomaly_maps[category] = anomaly_maps | |
| # Get masks if available (for metrics calculation) | |
| # Note: In full_image_eval mode, masks might not be available | |
| # This is a simplified version - can be extended if needed | |
| all_category_masks[category] = None | |
| all_category_labels[category] = None | |
| self.log_message(f"Completed evaluation for category: {category}") | |
| self.log_message(f"{'='*80}\n") | |
| # Wait for all images to be saved (after all categories) | |
| if self.save_plot_path and self.save_queue and not self._disable_file_saving: | |
| self.wait_for_saves() | |
| # Return results in a format compatible with standard evaluation | |
| # Combine all categories if multiple | |
| if len(categories) == 1: | |
| return all_category_anomaly_maps[categories[0]], all_category_masks[categories[0]], all_category_labels[categories[0]] | |
| else: | |
| # Combine all categories | |
| combined_anomaly_maps = {} | |
| for key in ['anomaly_geometric', 'anomaly_arithmetic', 'latent_discrepancy', 'image_discrepancy']: | |
| combined_anomaly_maps[key] = np.concatenate([ | |
| all_category_anomaly_maps[cat][key] for cat in categories | |
| if key in all_category_anomaly_maps[cat] | |
| ], axis=0) | |
| return combined_anomaly_maps, None, None | |
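| # Return-shape note: for a single category this yields (anomaly_maps, masks, labels), | |
| # where anomaly_maps maps 'anomaly_geometric' / 'anomaly_arithmetic' / | |
| # 'latent_discrepancy' / 'image_discrepancy' to per-image arrays stacked along axis 0, | |
| # and masks/labels are None in full-image mode; for multiple categories the per-key | |
| # arrays are concatenated along the image axis. | |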
| def _setup_evaluation(self, args): | |
| """Setup dataset, dataloader, and diffusion for evaluation. | |
| Extracted from evaluate() lines 970-1012 to make it composable and testable. | |
| Args: | |
| args: Configuration arguments | |
| Returns: | |
| Tuple[Dataset, DataLoader, Diffusion]: test_dataset, test_loader, diffusion | |
| """ | |
| test_dataset_manager = create_temporal_dataset_manager( | |
| dataset_name=args.dataset, | |
| mode='test', | |
| object_class=args.object_category, | |
| rootdir=args.data_dir, | |
| image_size=args.image_size, | |
| crop_size=args.crop_size, | |
| anomaly_class=getattr(args, 'anomaly_class', None), | |
| csv_split_file=getattr(args, 'csv_split_file', None) | |
| ) | |
| test_loader = create_temporal_dataloader( | |
| dataset_manager=test_dataset_manager, | |
| epoch=0, | |
| batch_size=args.batch_size, | |
| shuffle=False, # No shuffle for validation | |
| num_workers=args.num_workers, | |
| drop_last=False | |
| ) | |
| # Initialize diffusion | |
| diffusion = create_diffusion( | |
| f'ddim{args.reverse_steps}', | |
| predict_deviation=True, | |
| sigma_small=False, | |
| predict_xstart=False, | |
| diffusion_steps=10 | |
| ) | |
| return test_dataset_manager.current_dataset, test_loader, diffusion | |
| def _run_inference_iter(self, args, test_dataset, test_loader, diffusion): | |
| """Iterator yielding fundamentals from inference. | |
| Yields: (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths) | |
| """ | |
| from tqdm import tqdm | |
| import time | |
| for batch_idx, (x, mask, object_cls) in enumerate(tqdm(test_loader, desc="Detection")): | |
| batch_iter_start = time.time() | |
| # Check for shutdown request | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ Inference interrupted at batch {batch_idx}. Shutting down...") | |
| break | |
| data_load_time = time.time() - batch_iter_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Batch {batch_idx} data load: {data_load_time:.3f}s", level='debug') | |
| x = x.to(self.device) | |
| # Get image paths | |
| batch_size = x.shape[0] | |
| batch_image_paths = [] | |
| for b in range(batch_size): | |
| idx = batch_idx * args.batch_size + b | |
| if idx < len(test_dataset): | |
| if hasattr(test_dataset, 'image_paths'): | |
| batch_image_paths.append(test_dataset.image_paths[idx]) | |
| elif hasattr(test_dataset, 'data_list'): | |
| batch_image_paths.append(test_dataset.data_list[idx]) | |
| elif hasattr(test_dataset, 'data_df'): | |
| # Extract path from DataFrame (column 'image' contains the relative path) | |
| batch_image_paths.append(test_dataset.data_df.iloc[idx]['image']) | |
| else: | |
| batch_image_paths.append(f"image_{idx:04d}") | |
| # Prepare model kwargs with proper object_cls tensor formatting | |
| # Model expects context tensor with shape [batch_size, 1] | |
| if isinstance(object_cls, torch.Tensor): | |
| object_cls_tensor = object_cls.to(self.device) | |
| # Ensure shape is [batch_size, 1] for embedding | |
| if len(object_cls_tensor.shape) == 1: | |
| object_cls_tensor = object_cls_tensor.unsqueeze(1) | |
| elif len(object_cls_tensor.shape) == 0: | |
| # Scalar tensor, expand to [batch_size, 1] | |
| object_cls_tensor = object_cls_tensor.unsqueeze(0).unsqueeze(0).expand(x.shape[0], 1) | |
| else: | |
| # If it's a single int or list, convert to tensor [batch_size, 1] | |
| if isinstance(object_cls, (list, tuple)): | |
| object_cls_tensor = torch.tensor(object_cls, device=self.device, dtype=torch.long) | |
| else: | |
| object_cls_tensor = torch.tensor([object_cls] * x.shape[0], device=self.device, dtype=torch.long) | |
| # Ensure shape is [batch_size, 1] | |
| if len(object_cls_tensor.shape) == 1: | |
| object_cls_tensor = object_cls_tensor.unsqueeze(1) | |
| # Ensure dtype is long for embedding indices | |
| if object_cls_tensor.dtype != torch.long: | |
| object_cls_tensor = object_cls_tensor.long() | |
| # Prepare model kwargs (needed for both TTA and non-TTA) | |
| model_kwargs = {'context': object_cls_tensor, 'mask': None} | |
| # Check TTA mode - compute per-shift fundamentals for saving | |
| if args.pad_px is not None: | |
| # Compute TTA fundamentals per shift | |
| tta_result = self._compute_tta_per_shift_fundamentals( | |
| x, | |
| model_kwargs=model_kwargs, | |
| diffusion=diffusion, | |
| reverse_steps=args.reverse_steps, | |
| pad_px=args.pad_px, | |
| img_enlarge_px=getattr(args, 'img_enlarge_px', 4), | |
| stride=getattr(args, 'stride', 1), | |
| directions=getattr(args, 'directions', ('h', 'v', 'diag')), | |
| shift_method=getattr(args, 'shift_method', 'interpolate'), | |
| tta_batch_size=getattr(args, 'tta_batch_size', 8), | |
| eta=0.0, | |
| crop_size=args.crop_size, | |
| ) | |
| # Extract fundamentals and metadata | |
| fundamentals_dict = tta_result['fundamentals'] | |
| tta_metadata = { | |
| 'num_shifts': tta_result['num_shifts'], | |
| 'shifts': tta_result['shifts'], | |
| 'directions': tta_result['directions'], | |
| 'pad_px': tta_result['pad_px'], | |
| 'stride': tta_result['stride'], | |
| } | |
| # Yield 6-tuple for TTA mode (will be saved by _save_fundamentals_iter) | |
| yield (x, mask, object_cls, fundamentals_dict, batch_image_paths, tta_metadata) | |
| continue | |
| # Standard inference (no TTA) | |
| with torch.no_grad(): | |
| encoded = self.vae.encode(x).latent_dist.mean.mul_(0.18215) | |
| latent_samples = diffusion.ddim_deviation_sample_loop( | |
| self.model, | |
| x.shape, | |
| noise=encoded, | |
| clip_denoised=False, | |
| denoised_fn=None, | |
| model_kwargs=model_kwargs, | |
| device=self.device, | |
| progress=False, | |
| eta=0.0, | |
| ) | |
| image_samples = self.vae.decode(latent_samples / 0.18215).sample | |
| x0 = self.vae.decode(encoded / 0.18215).sample | |
| # Yield fundamentals (8-tuple) | |
| yield (x, mask, object_cls, x0, encoded, image_samples, latent_samples, batch_image_paths) | |
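| # Consumers distinguish the two yield formats by tuple length (a minimal sketch; | |
| # downstream stages in this module use the same convention): | |
| # | |
| #   for batch in self._run_inference_iter(args, ds, loader, diffusion): | |
| #       if len(batch) == 6:   # TTA: (x, mask, cls, fundamentals_dict, paths, tta_meta) | |
| #           x, mask, object_cls, fundamentals_dict, paths, tta_meta = batch | |
| #       else:                 # non-TTA: 8-tuple of per-batch fundamentals | |
| #           x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths = batch | |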
| def _save_fundamentals_iter(self, fundamental_iter, save_dir, args): | |
| """Save fundamentals as NPY files with single CSV manifest and pass through. | |
| Args: | |
| fundamental_iter: Iterator yielding either: | |
| - TTA: (x, mask, object_cls, fundamentals_dict, paths, tta_metadata) | |
| - Non-TTA: (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths) | |
| save_dir: Directory to save NPY files and CSV manifest | |
| args: Configuration arguments | |
| Yields: Same as input (pass-through) | |
| """ | |
| import os | |
| import numpy as np | |
| import csv | |
| os.makedirs(save_dir, exist_ok=True) | |
| # Check if consolidated NPZ format is requested | |
| use_consolidated_npz = getattr(args, 'use_consolidated_npz', False) | |
| # Check for existing CSV to resume from checkpoint | |
| csv_path = os.path.join(save_dir, 'fundamentals_manifest.csv') | |
| resume_from_batch = -1 | |
| global_img_idx = 0 # Global image counter across all batches | |
| if os.path.exists(csv_path): | |
| # Read last batch_idx from existing CSV and count processed images | |
| try: | |
| with open(csv_path, 'r', newline='', encoding='utf-8') as f: | |
| reader = csv.DictReader(f) | |
| last_row = None | |
| seen_images = set() # Track unique (batch_idx, image_idx_in_batch) pairs | |
| for row in reader: | |
| last_row = row | |
| batch_idx = int(row['batch_idx']) | |
| img_idx = int(row['image_idx_in_batch']) | |
| seen_images.add((batch_idx, img_idx)) | |
| if last_row is not None: | |
| resume_from_batch = int(last_row['batch_idx']) | |
| global_img_idx = len(seen_images) # Count unique images processed | |
| print(f"Resuming from batch {resume_from_batch + 1} (checkpoint found, {global_img_idx} images processed)") | |
| except Exception as e: | |
| print(f"Warning: Could not read checkpoint from {csv_path}: {e}") | |
| print("Starting from scratch") | |
| resume_from_batch = -1 | |
| global_img_idx = 0 | |
| # Open CSV manifest in appropriate mode | |
| if resume_from_batch >= 0: | |
| # Append mode - resume from checkpoint | |
| csv_file = open(csv_path, 'a', newline='', encoding='utf-8') | |
| csv_writer = csv.writer(csv_file) | |
| else: | |
| # Write mode - start fresh | |
| csv_file = open(csv_path, 'w', newline='', encoding='utf-8') | |
| csv_writer = csv.writer(csv_file) | |
| # Write header | |
| csv_writer.writerow([ | |
| 'batch_idx', 'image_idx_in_batch', 'original_file_path', 'full_file_path', | |
| 'fundamental_type', 'is_tta', 'total_shifts', 'shift_idx', | |
| 'shift_type', 'shift_amount', 'npy_filename' | |
| ]) | |
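| # Illustrative manifest rows (values hypothetical; filenames follow the | |
| # patterns used below): | |
| #   0,0,000,bottle/test/broken_large/000.png,x0,1,25,3,h,-2,000000_shift_03_x0.npy | |
| #   0,0,000,bottle/test/broken_large/000.png,consolidated,1,25,-1,consolidated,0,000000_fundamentals.npz | |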
| try: | |
| for batch_idx, batch_data in enumerate(fundamental_iter): | |
| # Check for shutdown request | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ Saving fundamentals interrupted at batch {batch_idx}. Shutting down...") | |
| break | |
| # Skip batches that are already processed (checkpoint resume) | |
| if batch_idx <= resume_from_batch: | |
| # Pass through without processing (same for TTA 6-tuples and non-TTA 8-tuples) | |
| yield batch_data | |
| continue | |
| # Detect TTA vs non-TTA by tuple length | |
| if len(batch_data) == 6: # TTA mode | |
| x, mask, object_cls, fundamentals_dict, paths, tta_metadata = batch_data | |
| is_tta = True | |
| else: # Non-TTA mode (8-tuple) | |
| x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths = batch_data | |
| is_tta = False | |
| if is_tta: | |
| # Save per-shift fundamentals | |
| num_shifts = tta_metadata['num_shifts'] | |
| shifts = tta_metadata['shifts'] | |
| for img_idx_in_batch, path in enumerate(paths): | |
| if use_consolidated_npz: | |
| # CONSOLIDATED NPZ FORMAT: Save all fundamentals in one .npz file | |
| original_image = x[img_idx_in_batch].cpu().numpy() | |
| # Collect all shift data | |
| consolidated_data = { | |
| 'original_image': original_image, | |
| 'num_shifts': num_shifts, | |
| 'shifts': np.array(shifts) # List of (dx, dy) tuples | |
| } | |
| # Add each fundamental type for all shifts | |
| for fund_type in ['x0', 'encoded', 'image_samples', 'latent_samples']: | |
| # Stack all shifts into single array [num_shifts, ...] | |
| shift_arrays = [] | |
| for shift_idx in range(num_shifts): | |
| fund_array = fundamentals_dict[fund_type][img_idx_in_batch][shift_idx] | |
| if fund_array is not None: | |
| shift_arrays.append(fund_array) | |
| if shift_arrays: | |
| consolidated_data[fund_type] = np.stack(shift_arrays, axis=0) | |
| # Save single compressed .npz file | |
| npz_filename = f"{global_img_idx:06d}_fundamentals.npz" | |
| np.savez_compressed( | |
| os.path.join(save_dir, npz_filename), | |
| **consolidated_data | |
| ) | |
| # Write single CSV row for consolidated file | |
| csv_writer.writerow([ | |
| batch_idx, | |
| img_idx_in_batch, | |
| os.path.splitext(os.path.basename(path))[0], # original_file_path | |
| path, # full_file_path | |
| 'consolidated', # fundamental_type | |
| 1, # is_tta | |
| num_shifts, | |
| -1, # shift_idx (-1 for consolidated) | |
| 'consolidated', # shift_type | |
| 0, # shift_amount | |
| npz_filename | |
| ]) | |
| else: | |
| # ORIGINAL MULTI-FILE FORMAT: Separate .npy per fundamental | |
| # Save original image (once per image, not per shift) | |
| original_image = x[img_idx_in_batch].cpu().numpy() | |
| original_npy_filename = f"{global_img_idx:06d}_original_image.npy" | |
| np.save(os.path.join(save_dir, original_npy_filename), original_image) | |
| # Write CSV row for original image | |
| csv_writer.writerow([ | |
| batch_idx, | |
| img_idx_in_batch, | |
| os.path.splitext(os.path.basename(path))[0], # original_file_path (without extension) | |
| path, # full_file_path (with extension) | |
| 'original_image', | |
| 1, # is_tta | |
| num_shifts, | |
| -1, # shift_idx (-1 means original, not shifted) | |
| 'original', # shift_type | |
| 0, # shift_amount | |
| original_npy_filename | |
| ]) | |
| for shift_idx, (dx, dy) in enumerate(shifts): | |
| # Determine shift type and amount | |
| shift_type, shift_amount = _get_shift_type_and_amount(dx, dy) | |
| for fund_type in ['x0', 'encoded', 'image_samples', 'latent_samples']: | |
| fund_array = fundamentals_dict[fund_type][img_idx_in_batch][shift_idx] | |
| if fund_array is None: | |
| continue # Skip invalid shifts | |
| # Save NPY file | |
| npy_filename = f"{global_img_idx:06d}_shift_{shift_idx:02d}_{fund_type}.npy" | |
| np.save(os.path.join(save_dir, npy_filename), fund_array) | |
| # Write CSV row | |
| csv_writer.writerow([ | |
| batch_idx, | |
| img_idx_in_batch, | |
| os.path.splitext(os.path.basename(path))[0], # original_file_path (without extension) | |
| path, # full_file_path (with extension) | |
| fund_type, | |
| 1, # is_tta | |
| num_shifts, | |
| shift_idx, | |
| shift_type, | |
| shift_amount, | |
| npy_filename | |
| ]) | |
| # Increment global image counter after processing each image | |
| global_img_idx += 1 | |
| # Flush CSV to ensure incremental writes | |
| csv_file.flush() | |
| # Pass through (6-tuple) | |
| yield (x, mask, object_cls, fundamentals_dict, paths, tta_metadata) | |
| else: | |
| # Non-TTA mode - save per-image files | |
| # Defensive: an upstream iterator may yield an 8-tuple whose fundamentals are None | |
| if x0 is None: | |
| # Nothing to save for this batch - just pass through | |
| yield (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths) | |
| continue | |
| # Save per-image fundamentals for non-TTA mode | |
| for img_idx_in_batch, path in enumerate(paths): | |
| # Extract single image from batch | |
| img_x = x[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy() | |
| img_x0 = x0[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy() | |
| img_encoded = encoded[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy() | |
| img_samples = image_samples[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy() | |
| img_latents = latent_samples[img_idx_in_batch:img_idx_in_batch+1].cpu().numpy() | |
| if use_consolidated_npz: | |
| # CONSOLIDATED NPZ FORMAT: Save all fundamentals in one .npz file | |
| consolidated_data = { | |
| 'original_image': img_x, | |
| 'x0': img_x0, | |
| 'encoded': img_encoded, | |
| 'image_samples': img_samples, | |
| 'latent_samples': img_latents | |
| } | |
| # Save single compressed .npz file | |
| npz_filename = f"{global_img_idx:06d}_fundamentals.npz" | |
| np.savez_compressed( | |
| os.path.join(save_dir, npz_filename), | |
| **consolidated_data | |
| ) | |
| # Write single CSV row for consolidated file | |
| csv_writer.writerow([ | |
| batch_idx, | |
| img_idx_in_batch, | |
| os.path.splitext(os.path.basename(path))[0], # original_file_path | |
| path, # full_file_path | |
| 'consolidated', # fundamental_type | |
| 0, # is_tta | |
| 0, # total_shifts | |
| -1, # shift_idx (-1 for consolidated) | |
| 'consolidated', # shift_type | |
| 0, # shift_amount | |
| npz_filename | |
| ]) | |
| else: | |
| # ORIGINAL MULTI-FILE FORMAT: Separate .npy per fundamental | |
| # Save NPY files with global image index | |
| np.save(os.path.join(save_dir, f"{global_img_idx:06d}_original_image.npy"), img_x) | |
| np.save(os.path.join(save_dir, f"{global_img_idx:06d}_x0.npy"), img_x0) | |
| np.save(os.path.join(save_dir, f"{global_img_idx:06d}_encoded.npy"), img_encoded) | |
| np.save(os.path.join(save_dir, f"{global_img_idx:06d}_image_samples.npy"), img_samples) | |
| np.save(os.path.join(save_dir, f"{global_img_idx:06d}_latent_samples.npy"), img_latents) | |
| # Write CSV rows for non-TTA (one row per fundamental type per image) | |
| for fund_type in ['original_image', 'x0', 'encoded', 'image_samples', 'latent_samples']: | |
| npy_filename = f"{global_img_idx:06d}_{fund_type}.npy" | |
| csv_writer.writerow([ | |
| batch_idx, | |
| img_idx_in_batch, | |
| os.path.splitext(os.path.basename(path))[0], # original_file_path (without extension) | |
| path, # full_file_path (with extension) | |
| fund_type, | |
| 0, # is_tta | |
| 0, # total_shifts | |
| 0, # shift_idx | |
| 'none', # shift_type | |
| 0, # shift_amount | |
| npy_filename | |
| ]) | |
| # Increment global image counter after processing each image | |
| global_img_idx += 1 | |
| # Flush CSV to ensure incremental writes | |
| csv_file.flush() | |
| # Pass through (8-tuple) | |
| yield (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths) | |
| finally: | |
| # Always close CSV file | |
| csv_file.close() | |
| print(f"Saved fundamentals manifest: {csv_path}") | |
| def _load_fundamentals_iter(self, save_dir, test_dataset, args): | |
| """Load fundamentals from NPY files, handling both TTA and non-TTA. | |
| Args: | |
| save_dir: Directory containing NPY files and CSV manifest | |
| test_dataset: Dataset (used for getting masks) | |
| args: Command-line arguments | |
| Yields: Either: | |
| - TTA: (None, mask, None, fundamentals_dict, paths, tta_metadata) | |
| - Non-TTA: (None, mask, None, x0, encoded, image_samples, latent_samples, paths) | |
| """ | |
| import os | |
| import csv | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from collections import defaultdict | |
| # Read CSV manifest | |
| csv_path = os.path.join(save_dir, 'fundamentals_manifest.csv') | |
| if not os.path.exists(csv_path): | |
| raise FileNotFoundError(f"CSV manifest not found: {csv_path}") | |
| # Parse CSV and group by batch_idx | |
| batches = defaultdict(lambda: { | |
| 'is_tta': None, | |
| 'paths': [], | |
| 'rows': [] # Store all CSV rows for this batch | |
| }) | |
| with open(csv_path, 'r', encoding='utf-8') as f: | |
| reader = csv.DictReader(f) | |
| for row in reader: | |
| batch_idx = int(row['batch_idx']) | |
| is_tta = int(row['is_tta']) == 1 | |
| # Set is_tta for this batch (should be consistent) | |
| if batches[batch_idx]['is_tta'] is None: | |
| batches[batch_idx]['is_tta'] = is_tta | |
| # Track unique paths - FIX: use full_file_path instead of original_file_path | |
| # to avoid collisions when different subdirectories have same filename | |
| # (e.g., wood/test/color/000.png and wood/test/combined/000.png both map to "000") | |
| path = row['full_file_path'] | |
| if path not in batches[batch_idx]['paths']: | |
| batches[batch_idx]['paths'].append(path) | |
| # Store row for later processing | |
| batches[batch_idx]['rows'].append(row) | |
| # Load fundamentals for each batch | |
| for batch_idx in tqdm(sorted(batches.keys()), desc="Loading fundamentals"): | |
| # Check for shutdown request | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ Loading fundamentals interrupted at batch {batch_idx}. Shutting down...") | |
| break | |
| batch_info = batches[batch_idx] | |
| paths = batch_info['paths'] | |
| is_tta = batch_info['is_tta'] | |
| # Load ground truth masks from dataset | |
| batch_size = len(paths) | |
| masks_batch = [] | |
| # FIX: Use fixed batch size from args for index calculation | |
| # Same fix as in _load_fundamentals_iter_optimized (see the detailed comment there) | |
| fixed_batch_size = args.batch_size | |
| for b in range(batch_size): | |
| dataset_idx = batch_idx * fixed_batch_size + b | |
| if dataset_idx < len(test_dataset): | |
| _, mask, _ = test_dataset[dataset_idx] | |
| masks_batch.append(mask) | |
| # Skip this batch if we don't have ground truth masks | |
| if not masks_batch: | |
| if self.debug_enabled: | |
| self.log_message(f" [SKIP] Batch {batch_idx}: no ground truth data available", level='debug') | |
| continue | |
| mask = torch.stack(masks_batch) if len(masks_batch) > 1 else masks_batch[0].unsqueeze(0) | |
| # Filter paths to match available masks | |
| num_masks = len(masks_batch) | |
| if num_masks < len(paths): | |
| paths = paths[:num_masks] | |
| if is_tta: | |
| # Load per-shift fundamentals from CSV rows | |
| rows = batch_info['rows'] | |
| # Determine batch size and num_shifts (use num_masks not len(paths)) | |
| num_images = num_masks | |
| num_shifts = max(int(row['total_shifts']) for row in rows) | |
| # Initialize fundamentals storage | |
| # Note: original_image is saved in manifest but not loaded here (only used for visualization) | |
| fundamentals_dict = { | |
| 'x0': [[None for _ in range(num_shifts)] for _ in range(num_images)], | |
| 'encoded': [[None for _ in range(num_shifts)] for _ in range(num_images)], | |
| 'image_samples': [[None for _ in range(num_shifts)] for _ in range(num_images)], | |
| 'latent_samples': [[None for _ in range(num_shifts)] for _ in range(num_images)], | |
| } | |
| shifts = [] # Reconstruct shift list | |
| shift_coords_seen = {} # Track (dx, dy) -> shift_idx mapping | |
| # Load each file | |
| for row in rows: | |
| img_idx = int(row['image_idx_in_batch']) | |
| shift_idx = int(row['shift_idx']) | |
| fund_type = row['fundamental_type'] | |
| npy_filename = row['npy_filename'] | |
| # Skip images beyond available masks | |
| if img_idx >= num_masks: | |
| continue | |
| # Skip original_image during loading (it's only used for visualization, not evaluation) | |
| if fund_type == 'original_image': | |
| continue | |
| # Load NPY file | |
| npy_path = os.path.join(save_dir, npy_filename) | |
| fund_array = np.load(npy_path) | |
| # Store in fundamentals dict | |
| fundamentals_dict[fund_type][img_idx][shift_idx] = fund_array | |
| # Reconstruct shifts list (only once per shift_idx) | |
| if fund_type == 'x0' and img_idx == 0: # Use first image, first fundamental type | |
| shift_type = row['shift_type'] | |
| shift_amount = int(row['shift_amount']) | |
| # Convert back to (dx, dy) | |
| if shift_type == 'none': | |
| dx, dy = 0, 0 | |
| elif shift_type == 'h': | |
| dx, dy = shift_amount, 0 | |
| elif shift_type == 'v': | |
| dx, dy = 0, shift_amount | |
| elif shift_type == 'diag': | |
| # For diagonal, both dx and dy have same sign | |
| # Positive shift_amount -> both positive, negative -> both negative | |
| dx = dy = shift_amount | |
| # Only add if we haven't seen this shift_idx yet | |
| if shift_idx not in shift_coords_seen: | |
| shift_coords_seen[shift_idx] = (dx, dy) | |
| shifts.append((dx, dy)) | |
| # Sort shifts by their index to maintain order | |
| shifts_sorted = [shift_coords_seen[i] for i in sorted(shift_coords_seen.keys())] | |
| # Create TTA metadata | |
| tta_metadata = { | |
| 'shifts': shifts_sorted, | |
| 'num_shifts': num_shifts, | |
| 'directions': None, # Not stored in CSV | |
| 'pad_px': None, # Not stored in CSV | |
| 'stride': None # Not stored in CSV | |
| } | |
| # Yield TTA format (6-tuple) | |
| yield (None, mask, None, fundamentals_dict, paths, tta_metadata) | |
| else: | |
| # Non-TTA mode - load the per-image files listed in the manifest. | |
| # The saver names these files by a *global* image index (e.g. 000003_x0.npy), | |
| # so filenames must come from the CSV rows rather than being rebuilt from batch_idx. | |
| loaded = {ft: [] for ft in ('x0', 'encoded', 'image_samples', 'latent_samples')} | |
| for row in batch_info['rows']: | |
| if row['fundamental_type'] in loaded: | |
| loaded[row['fundamental_type']].append(np.load(os.path.join(save_dir, row['npy_filename']))) | |
| x0 = torch.from_numpy(np.concatenate(loaded['x0'], axis=0)) | |
| encoded = torch.from_numpy(np.concatenate(loaded['encoded'], axis=0)) | |
| image_samples = torch.from_numpy(np.concatenate(loaded['image_samples'], axis=0)) | |
| latent_samples = torch.from_numpy(np.concatenate(loaded['latent_samples'], axis=0)) | |
| # Yield non-TTA format (8-tuple) | |
| yield (None, mask, None, x0, encoded, image_samples, latent_samples, paths) | |
| # ============================================================================ | |
| # PHASE 1: I/O PARALLELIZATION - ParallelNPYLoader | |
| # ============================================================================ | |
| class ParallelNPYLoader: | |
| """Parallel NPY file loader with memory mapping and prefetch. | |
| Features: | |
| - ThreadPoolExecutor for parallel file I/O | |
| - Memory-mapped loading for large files | |
| - LRU caching for frequently accessed files | |
| - Rolling prefetch queue | |
| """ | |
| def __init__(self, save_dir, num_workers=8, use_mmap=True, cache_size_mb=200): | |
| """Initialize parallel NPY loader. | |
| Args: | |
| save_dir: Directory containing NPY files | |
| num_workers: Number of parallel I/O workers | |
| use_mmap: Use memory-mapped mode for large files | |
| cache_size_mb: Maximum cache size in MB | |
| """ | |
| from concurrent.futures import ThreadPoolExecutor | |
| self.save_dir = save_dir | |
| self.num_workers = num_workers | |
| self.use_mmap = use_mmap | |
| self.cache_size_bytes = cache_size_mb * 1024 * 1024 | |
| self.executor = ThreadPoolExecutor(max_workers=num_workers) | |
| self.cache = {} # filepath -> numpy array | |
| self.cache_order = [] # LRU tracking | |
| self.cache_memory = 0 # Current cache size in bytes | |
| self.format_logged = False # Track if format has been logged | |
| def load_npy_file(self, filepath): | |
| """Load single NPY file with memory mapping and caching. | |
| Args: | |
| filepath: Absolute path to NPY file | |
| Returns: | |
| numpy.ndarray: Loaded array | |
| """ | |
| import numpy as np | |
| import os | |
| # Check cache first | |
| if filepath in self.cache: | |
| # Move to end (most recently used) | |
| self.cache_order.remove(filepath) | |
| self.cache_order.append(filepath) | |
| return self.cache[filepath] | |
| # Load from disk | |
| if not os.path.exists(filepath): | |
| raise FileNotFoundError(f"NPY file not found: {filepath}") | |
| if self.use_mmap: | |
| arr = np.load(filepath, mmap_mode='r') | |
| arr_size = arr.nbytes | |
| # Cache small files (< 50MB) in memory | |
| if arr_size < 50 * 1024 * 1024: | |
| arr = np.array(arr) # Force load into memory | |
| self._add_to_cache(filepath, arr, arr_size) | |
| else: | |
| arr = np.load(filepath) | |
| self._add_to_cache(filepath, arr, arr.nbytes) | |
| return arr | |
| def load_npz_consolidated(self, filepath): | |
| """Load consolidated .npz file containing all fundamentals. | |
| Args: | |
| filepath: Absolute path to .npz file | |
| Returns: | |
| dict: Dictionary with keys like 'x0', 'encoded', 'image_samples', 'latent_samples', etc. | |
| """ | |
| import numpy as np | |
| import os | |
| # Check cache first | |
| if filepath in self.cache: | |
| self.cache_order.remove(filepath) | |
| self.cache_order.append(filepath) | |
| return self.cache[filepath] | |
| # Load from disk | |
| if not os.path.exists(filepath): | |
| raise FileNotFoundError(f"NPZ file not found: {filepath}") | |
| # Load NPZ file (note: compressed .npz members are decompressed on access, | |
| # so mmap_mode has little effect here) | |
| if self.use_mmap: | |
| data = np.load(filepath, mmap_mode='r') | |
| else: | |
| data = np.load(filepath) | |
| # Extract all arrays into dictionary | |
| consolidated_data = {key: np.array(data[key]) for key in data.files} | |
| # Calculate total size | |
| total_size = sum(arr.nbytes if isinstance(arr, np.ndarray) else 0 | |
| for arr in consolidated_data.values()) | |
| # Cache if small enough (< 50MB) | |
| if total_size < 50 * 1024 * 1024: | |
| self._add_to_cache(filepath, consolidated_data, total_size) | |
| return consolidated_data | |
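| # Expected layout of a consolidated TTA .npz (as written by _save_fundamentals_iter): | |
| # 'original_image' [C, H, W], 'num_shifts' scalar, 'shifts' [num_shifts, 2], and | |
| # 'x0' / 'encoded' / 'image_samples' / 'latent_samples' each stacked to | |
| # [num_shifts, 1, C, H, W] (invalid shifts are skipped at save time, so the first | |
| # axis can be smaller than num_shifts). | |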
| def _add_to_cache(self, filepath, arr, arr_size): | |
| """Add array or dict of arrays to cache with LRU eviction. | |
| Args: | |
| filepath: File path (cache key) | |
| arr: Numpy array or dict of arrays to cache | |
| arr_size: Size in bytes (pre-calculated) | |
| """ | |
| import gc | |
| import numpy as np | |
| # Evict LRU items if needed | |
| evicted_count = 0 | |
| while (self.cache_memory + arr_size > self.cache_size_bytes and | |
| len(self.cache) > 0): | |
| lru_path = self.cache_order.pop(0) | |
| evicted_item = self.cache.pop(lru_path) | |
| # Calculate size based on type (array or dict of arrays) | |
| if isinstance(evicted_item, dict): | |
| evicted_size = sum(v.nbytes if isinstance(v, np.ndarray) else 0 | |
| for v in evicted_item.values()) | |
| else: | |
| evicted_size = evicted_item.nbytes | |
| self.cache_memory -= evicted_size | |
| del evicted_item | |
| evicted_count += 1 | |
| # Force garbage collection if we evicted items | |
| if evicted_count > 0: | |
| gc.collect() | |
| # Add to cache | |
| self.cache[filepath] = arr | |
| self.cache_order.append(filepath) | |
| self.cache_memory += arr_size | |
| def load_batch_fundamentals_parallel(self, batch_info, save_dir): | |
| """Load all fundamentals for a batch in parallel. | |
| Args: | |
| batch_info: Dict with 'rows' key containing CSV rows | |
| save_dir: Directory containing NPY files | |
| Returns: | |
| dict: Fundamentals dict with loaded arrays | |
| """ | |
| import os | |
| import numpy as np | |
| from concurrent.futures import as_completed | |
| rows = batch_info['rows'] | |
| is_tta = batch_info['is_tta'] | |
| if is_tta: | |
| # Check if using consolidated NPZ format | |
| first_row_fund_type = rows[0]['fundamental_type'] | |
| is_consolidated = (first_row_fund_type == 'consolidated') | |
| # Log format detection (once, on first batch) | |
| if not self.format_logged: | |
| self.format_logged = True | |
| format_str = "Consolidated NPZ (1 file/image)" if is_consolidated else "Multi-file NPY (33 files/image)" | |
| print(f" Detected format: {format_str}") | |
| # Determine dimensions | |
| num_images = len(batch_info['paths']) | |
| num_shifts = max(int(row['total_shifts']) for row in rows) | |
| # Initialize fundamentals storage | |
| fundamentals_dict = { | |
| 'x0': [[None for _ in range(num_shifts)] for _ in range(num_images)], | |
| 'encoded': [[None for _ in range(num_shifts)] for _ in range(num_images)], | |
| 'image_samples': [[None for _ in range(num_shifts)] for _ in range(num_images)], | |
| 'latent_samples': [[None for _ in range(num_shifts)] for _ in range(num_images)], | |
| } | |
| # For consolidated format, store shifts from NPZ file | |
| loaded_shifts = None | |
| if is_consolidated: | |
| # CONSOLIDATED NPZ FORMAT: Load one .npz per image | |
| futures = {} | |
| for row in rows: | |
| if row['fundamental_type'] != 'consolidated': | |
| continue | |
| npz_filename = row['npy_filename'] # Actually .npz | |
| npz_path = os.path.join(save_dir, npz_filename) | |
| # Submit load job for consolidated file | |
| future = self.executor.submit(self.load_npz_consolidated, npz_path) | |
| futures[future] = row | |
| # Collect results as they complete | |
| for future in as_completed(futures): | |
| row = futures[future] | |
| try: | |
| consolidated_data = future.result(timeout=30) | |
| # Extract image index | |
| img_idx = int(row['image_idx_in_batch']) | |
| # FIX: Extract shifts from first NPZ file (all images have same shifts) | |
| if loaded_shifts is None and 'shifts' in consolidated_data: | |
| loaded_shifts = consolidated_data['shifts'] # shape: (num_shifts, 2) | |
| # Populate fundamentals_dict from consolidated data | |
| for fund_type in ['x0', 'encoded', 'image_samples', 'latent_samples']: | |
| if fund_type in consolidated_data: | |
| fund_arrays = consolidated_data[fund_type] # [num_shifts, ...] | |
| # Assign each shift | |
| for shift_idx in range(len(fund_arrays)): | |
| if img_idx < len(fundamentals_dict[fund_type]): | |
| fundamentals_dict[fund_type][img_idx][shift_idx] = fund_arrays[shift_idx] | |
| except Exception as e: | |
| print(f"Warning: Failed to load consolidated {row['npy_filename']}: {e}") | |
| continue | |
| # FIX: Store loaded_shifts in fundamentals_dict for later use | |
| if loaded_shifts is not None: | |
| fundamentals_dict['_shifts'] = loaded_shifts | |
| else: | |
| # ORIGINAL MULTI-FILE FORMAT: Load individual .npy files | |
| # Submit all loads to thread pool | |
| futures = {} | |
| for row in rows: | |
| fund_type = row['fundamental_type'] | |
| # Skip original_image (not used for evaluation) | |
| if fund_type == 'original_image': | |
| continue | |
| npy_filename = row['npy_filename'] | |
| npy_path = os.path.join(save_dir, npy_filename) | |
| # Submit load job | |
| future = self.executor.submit(self.load_npy_file, npy_path) | |
| futures[future] = row | |
| # Collect results as they complete | |
| for future in as_completed(futures): | |
| row = futures[future] | |
| try: | |
| fund_array = future.result(timeout=30) # 30s timeout per file | |
| # Store in fundamentals dict | |
| img_idx = int(row['image_idx_in_batch']) | |
| shift_idx = int(row['shift_idx']) | |
| fund_type = row['fundamental_type'] | |
| # Bounds check: skip if img_idx is out of range | |
| # (happens when cache was created with different dataset filters) | |
| if img_idx >= len(fundamentals_dict[fund_type]): | |
| print(f"Warning: Skipping {row['npy_filename']}: img_idx {img_idx} out of range (batch has {len(fundamentals_dict[fund_type])} images)") | |
| continue | |
| fundamentals_dict[fund_type][img_idx][shift_idx] = fund_array | |
| except Exception as e: | |
| # Log error but continue | |
| print(f"Warning: Failed to load {row['npy_filename']}: {e}") | |
| continue | |
| return fundamentals_dict | |
| else: | |
| # Non-TTA mode | |
| # Check if using consolidated NPZ format | |
| first_row_fund_type = rows[0]['fundamental_type'] | |
| is_consolidated = (first_row_fund_type == 'consolidated') | |
| # Log format detection (once, on first batch) | |
| if not self.format_logged: | |
| self.format_logged = True | |
| format_str = "Consolidated NPZ (1 file/image)" if is_consolidated else "Multi-file NPY (4 files/image)" | |
| print(f" Detected format: {format_str}") | |
| if is_consolidated: | |
| # CONSOLIDATED NPZ FORMAT: Load single .npz file | |
| npz_filename = rows[0]['npy_filename'] | |
| npz_path = os.path.join(save_dir, npz_filename) | |
| try: | |
| consolidated_data = self.load_npz_consolidated(npz_path) | |
| # Extract fundamentals | |
| results = { | |
| 'x0': consolidated_data.get('x0'), | |
| 'encoded': consolidated_data.get('encoded'), | |
| 'image_samples': consolidated_data.get('image_samples'), | |
| 'latent_samples': consolidated_data.get('latent_samples') | |
| } | |
| return results | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load consolidated NPZ {npz_filename}: {e}") | |
| else: | |
| # ORIGINAL MULTI-FILE FORMAT: load the per-image files listed in the manifest | |
| # (the saver names them by global image index, e.g. 000003_x0.npy, so | |
| # filenames come from the CSV rows rather than being rebuilt from batch_idx) | |
| file_types = ('x0', 'encoded', 'image_samples', 'latent_samples') | |
| parts = {ftype: [] for ftype in file_types} | |
| # Submit loads (reserve a slot per file so manifest order is preserved) | |
| futures = {} | |
| for row in rows: | |
| ftype = row['fundamental_type'] | |
| if ftype not in parts: | |
| continue  # skip original_image and other non-fundamental rows | |
| filepath = os.path.join(save_dir, row['npy_filename']) | |
| future = self.executor.submit(self.load_npy_file, filepath) | |
| futures[future] = (ftype, len(parts[ftype])) | |
| parts[ftype].append(None) | |
| # Collect results as they complete, writing each into its reserved slot | |
| for future in as_completed(futures): | |
| ftype, slot = futures[future] | |
| try: | |
| parts[ftype][slot] = future.result(timeout=30) | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to load {ftype}: {e}") | |
| # Concatenate per-image [1, C, H, W] arrays into batch arrays [B, C, H, W] | |
| results = {ftype: np.concatenate(arrs, axis=0) for ftype, arrs in parts.items() if arrs} | |
| return results | |
| def shutdown(self): | |
| """Shutdown thread pool executor.""" | |
| self.executor.shutdown(wait=True) | |
| self.cache.clear() | |
| self.cache_order.clear() | |
| self.cache_memory = 0 | |
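| # Minimal usage sketch (paths hypothetical; in this module the class is | |
| # reached through the evaluator instance as self.ParallelNPYLoader): | |
| # | |
| #   loader = ParallelNPYLoader(save_dir='fundamentals_cache', num_workers=8) | |
| #   try: | |
| #       arr = loader.load_npy_file('fundamentals_cache/000000_x0.npy') | |
| #       bundle = loader.load_npz_consolidated('fundamentals_cache/000000_fundamentals.npz') | |
| #   finally: | |
| #       loader.shutdown()  # joins workers and clears the LRU cache | |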
| def _load_fundamentals_iter_optimized(self, save_dir, test_dataset, args): | |
| """Optimized version with parallel I/O and prefetching. | |
| This is the Phase 1 optimization that uses ParallelNPYLoader for 2-3× speedup. | |
| Args: | |
| save_dir: Directory containing NPY files and CSV manifest | |
| test_dataset: Dataset (used for getting masks) | |
| args: Command-line arguments with optimization flags | |
| Yields: Either: | |
| - TTA: (None, mask, None, fundamentals_dict, paths, tta_metadata) | |
| - Non-TTA: (None, mask, None, x0, encoded, image_samples, latent_samples, paths) | |
| """ | |
| import os | |
| import csv | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from collections import defaultdict | |
| from concurrent.futures import as_completed | |
| # Get I/O workers and prefetch settings from args | |
| io_workers = getattr(args, 'io_workers', 16) # Increased default for SSDs | |
| use_mmap = getattr(args, 'use_mmap', True) | |
| prefetch_window = getattr(args, 'prefetch_window', 4) # Configurable prefetch | |
| # Log I/O optimization settings for verification | |
| print(f"\n{'='*70}") | |
| print(f" NPY Loader Configuration (evaluate_only mode)") | |
| print(f"{'='*70}") | |
| print(f" I/O Workers: {io_workers} threads") | |
| print(f" Prefetch Window: {prefetch_window} batches") | |
| print(f" Memory Mapping: {'Enabled' if use_mmap else 'Disabled'}") | |
| print(f" Cache Size: 200 MB") | |
| print(f"{'='*70}\n") | |
| # Initialize parallel loader | |
| parallel_loader = self.ParallelNPYLoader( | |
| save_dir=save_dir, | |
| num_workers=io_workers, | |
| use_mmap=use_mmap, | |
| cache_size_mb=200 | |
| ) | |
| try: | |
| # Read CSV manifest (same as original) | |
| csv_path = os.path.join(save_dir, 'fundamentals_manifest.csv') | |
| if not os.path.exists(csv_path): | |
| raise FileNotFoundError(f"CSV manifest not found: {csv_path}") | |
| # Parse CSV and group by batch_idx | |
| batches = defaultdict(lambda: { | |
| 'is_tta': None, | |
| 'paths': [], | |
| 'rows': [] | |
| }) | |
| with open(csv_path, 'r', encoding='utf-8') as f: | |
| reader = csv.DictReader(f) | |
| for row in reader: | |
| batch_idx = int(row['batch_idx']) | |
| is_tta = int(row['is_tta']) == 1 | |
| if batches[batch_idx]['is_tta'] is None: | |
| batches[batch_idx]['is_tta'] = is_tta | |
| # FIX: use full_file_path instead of original_file_path | |
| # to avoid collisions when different subdirectories have same filename | |
| path = row['full_file_path'] | |
| if path not in batches[batch_idx]['paths']: | |
| batches[batch_idx]['paths'].append(path) | |
| batches[batch_idx]['rows'].append(row) | |
| sorted_batch_ids = sorted(batches.keys()) | |
| # Prefetch system: Load batches in rolling window (configurable via --prefetch_window) | |
| # prefetch_window already set from args above | |
| prefetch_futures = {} | |
| # Prefetch first batch(es) | |
| for i in range(min(prefetch_window, len(sorted_batch_ids))): | |
| batch_idx = sorted_batch_ids[i] | |
| future = parallel_loader.executor.submit( | |
| parallel_loader.load_batch_fundamentals_parallel, | |
| batches[batch_idx], | |
| save_dir | |
| ) | |
| prefetch_futures[batch_idx] = future | |
| # Process batches with rolling prefetch | |
| for i, batch_idx in enumerate(tqdm(sorted_batch_ids, desc="Loading fundamentals (optimized)")): | |
| # Check for shutdown | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ Loading fundamentals interrupted at batch {batch_idx}. Shutting down...") | |
| break | |
| batch_info = batches[batch_idx] | |
| paths = batch_info['paths'] | |
| is_tta = batch_info['is_tta'] | |
| # Get fundamentals (from prefetch or load now) | |
| if batch_idx in prefetch_futures: | |
| fundamentals_or_results = prefetch_futures[batch_idx].result(timeout=120) | |
| del prefetch_futures[batch_idx] | |
| else: | |
| # Not prefetched, load now | |
| fundamentals_or_results = parallel_loader.load_batch_fundamentals_parallel( | |
| batch_info, | |
| save_dir | |
| ) | |
| # Prefetch next batch (rolling window) | |
| next_idx = i + prefetch_window | |
| if next_idx < len(sorted_batch_ids): | |
| next_batch_idx = sorted_batch_ids[next_idx] | |
| future = parallel_loader.executor.submit( | |
| parallel_loader.load_batch_fundamentals_parallel, | |
| batches[next_batch_idx], | |
| save_dir | |
| ) | |
| prefetch_futures[next_batch_idx] = future | |
| # Load ground truth masks (same as original) | |
| batch_size = len(paths) | |
| masks_batch = [] | |
| # FIX: Use fixed batch size from args for index calculation | |
| # The bug was using variable batch_size (len(paths)) which fails for partial batches | |
| # E.g., if last batch has 9 items: dataset_idx = 7 * 9 = 63 (wrong!) | |
| # Should be: dataset_idx = 7 * 10 = 70 (correct, using fixed batch_size=10) | |
| fixed_batch_size = args.batch_size | |
| for b in range(batch_size): | |
| dataset_idx = batch_idx * fixed_batch_size + b | |
| if dataset_idx < len(test_dataset): | |
| _, mask, _ = test_dataset[dataset_idx] | |
| masks_batch.append(mask) | |
| # Skip this batch if we don't have ground truth masks | |
| if not masks_batch: | |
| if self.debug_enabled: | |
| self.log_message(f" [SKIP] Batch {batch_idx}: no ground truth data available", level='debug') | |
| continue | |
| mask = torch.stack(masks_batch)  # torch.stack also handles the single-mask case (adds the batch dim) | |
| # Filter fundamentals and paths to match available masks | |
| num_masks = len(masks_batch) | |
| if num_masks < len(paths): | |
| paths = paths[:num_masks] | |
| if is_tta: | |
| # Fundamentals already loaded by parallel loader | |
| fundamentals_dict = fundamentals_or_results | |
| # DEBUG: Check fundamentals structure and verify shifts are different | |
| if batch_idx == 0: | |
| print(f"\n[DEBUG] Batch {batch_idx} fundamentals_dict from CACHE:") | |
| for key in ['x0']: | |
| print(f" {key}: type={type(fundamentals_dict[key])}, len={len(fundamentals_dict[key])}") | |
| if len(fundamentals_dict[key]) > 0: | |
| first = fundamentals_dict[key][0] | |
| print(f" first element: type={type(first)}, len={len(first) if isinstance(first, list) else 'N/A'}") | |
| if isinstance(first, list) and len(first) > 1: | |
| if first[0] is not None and first[1] is not None: | |
| import numpy as np | |
| arr0 = first[0] if isinstance(first[0], np.ndarray) else first[0].cpu().numpy() | |
| arr1 = first[1] if isinstance(first[1], np.ndarray) else first[1].cpu().numpy() | |
| diff = np.abs(arr0 - arr1).sum() | |
| print(f" Shift 0 vs Shift 1 |diff|: {diff:.6f} (CRITICAL: must be >0!)") | |
| print(f" arr0: shape={arr0.shape}, mean={arr0.mean():.4f}") | |
| print(f" arr1: shape={arr1.shape}, mean={arr1.mean():.4f}") | |
| # Filter fundamentals to match available masks | |
| if num_masks < len(fundamentals_dict['x0']): | |
| for key in fundamentals_dict.keys(): | |
| # Skip metadata keys (e.g. '_shifts'): they are indexed per shift, not per image, so truncating them to num_masks would corrupt them | |
| if key.startswith('_'): | |
| continue | |
| fundamentals_dict[key] = fundamentals_dict[key][:num_masks] | |
| # Reconstruct shifts list | |
| rows = batch_info['rows'] | |
| num_shifts = max(int(row['total_shifts']) for row in rows) | |
| # FIX: For consolidated format, use shifts loaded from NPZ file | |
| if '_shifts' in fundamentals_dict: | |
| # Consolidated format: shifts are in NPZ file | |
| import numpy as np | |
| shifts_array = fundamentals_dict['_shifts'] # shape: (num_shifts, 2) | |
| shifts_sorted = [(int(shifts_array[i, 0]), int(shifts_array[i, 1])) for i in range(len(shifts_array))] | |
| else: | |
| # Multi-file format: reconstruct shifts from manifest | |
| shift_coords_seen = {} | |
| for row in rows: | |
| img_idx = int(row['image_idx_in_batch']) | |
| shift_idx = int(row['shift_idx']) | |
| fund_type = row['fundamental_type'] | |
| # Reconstruct shifts list (only once per shift_idx) | |
| if fund_type == 'x0' and img_idx == 0: | |
| shift_type = row['shift_type'] | |
| shift_amount = int(row['shift_amount']) | |
| # Convert back to (dx, dy) | |
| if shift_type == 'none': | |
| dx, dy = 0, 0 | |
| elif shift_type == 'h': | |
| dx, dy = shift_amount, 0 | |
| elif shift_type == 'v': | |
| dx, dy = 0, shift_amount | |
| elif shift_type == 'diag': | |
| dx = dy = shift_amount | |
| else: | |
| raise ValueError(f"Unknown shift_type in manifest: {shift_type!r}") | |
| if shift_idx not in shift_coords_seen: | |
| shift_coords_seen[shift_idx] = (dx, dy) | |
| # Sort shifts by index | |
| shifts_sorted = [shift_coords_seen[i] for i in sorted(shift_coords_seen.keys())] | |
| # DEBUG: Check if shifts were reconstructed | |
| if batch_idx == 0: | |
| print(f"[DEBUG] Reconstructed shifts: {shifts_sorted}") | |
| print(f"[DEBUG] num_shifts: {num_shifts}") | |
| # Create TTA metadata | |
| tta_metadata = { | |
| 'shifts': shifts_sorted, | |
| 'num_shifts': num_shifts, | |
| 'directions': None, | |
| 'pad_px': None, | |
| 'stride': None | |
| } | |
| # Yield TTA format | |
| yield (None, mask, None, fundamentals_dict, paths, tta_metadata) | |
| else: | |
| # Non-TTA: Convert to tensors | |
| results = fundamentals_or_results | |
| x0 = torch.from_numpy(results['x0']) | |
| encoded = torch.from_numpy(results['encoded']) | |
| image_samples = torch.from_numpy(results['image_samples']) | |
| latent_samples = torch.from_numpy(results['latent_samples']) | |
| # Yield non-TTA format | |
| yield (None, mask, None, x0, encoded, image_samples, latent_samples, paths) | |
| finally: | |
| # Cleanup | |
| parallel_loader.shutdown() | |
| def _load_fundamentals_iter_dispatcher(self, save_dir, test_dataset, args): | |
| """Dispatcher for _load_fundamentals_iter with legacy mode support. | |
| Args: | |
| save_dir: Directory containing NPY files | |
| test_dataset: Dataset for masks | |
| args: Command-line arguments | |
| Yields: | |
| Batch data (format depends on TTA mode) | |
| """ | |
| # Check if parallel I/O is enabled | |
| legacy_mode = getattr(args, 'legacy_mode', False) | |
| parallel_io_enabled = getattr(args, 'parallel_io', True) | |
| if legacy_mode or not parallel_io_enabled: | |
| # Use original sequential implementation | |
| yield from self._load_fundamentals_iter(save_dir, test_dataset, args) | |
| else: | |
| # Use optimized parallel implementation (Phase 1) | |
| yield from self._load_fundamentals_iter_optimized(save_dir, test_dataset, args) | |
| def _fuse_array(self, stack, fuse_method, axis=-1): | |
| """Efficiently fuse array using partition-based methods (handles NaN). | |
| Args: | |
| stack: np.ndarray to fuse | |
| fuse_method: 'mean', 'median', 'lowest', 'pct25', 'pct75' | |
| axis: Axis to fuse along (default: -1, last dimension) | |
| Returns: | |
| Fused array | |
| """ | |
| if fuse_method == 'mean': | |
| return np.nanmean(stack, axis=axis) | |
| elif fuse_method == 'lowest': | |
| return np.nanmin(stack, axis=axis) | |
| # For partition-based methods, move target axis to the end for easier indexing | |
| # This allows us to use [..., k] pattern like the existing optimized code | |
| if axis != -1 and axis != stack.ndim - 1: | |
| # Move axis to end | |
| stack = np.moveaxis(stack, axis, -1) | |
| moved_axis = True | |
| else: | |
| moved_axis = False | |
| # Handle NaN: replace with inf for partition, then restore | |
| nan_mask = np.isnan(stack) | |
| stack_clean = np.where(nan_mask, np.inf, stack) | |
| n_valid = stack.shape[-1]  # note: counts NaN slots too, so ranks are exact only for pixels with no missing shifts | |
| if fuse_method == 'median': | |
| # Optimize: use partition for median (much faster than full sort) | |
| if n_valid % 2 == 1: | |
| # Odd number: median is at index (n-1)//2 | |
| k = (n_valid - 1) // 2 | |
| averaged = np.partition(stack_clean, k, axis=-1)[..., k] | |
| else: | |
| # Even number: median is average of two middle elements | |
| k1 = n_valid // 2 - 1 | |
| k2 = n_valid // 2 | |
| part = np.partition(stack_clean, [k1, k2], axis=-1) | |
| averaged = (part[..., k1] + part[..., k2]) / 2.0 | |
| elif fuse_method == 'pct25': | |
| # Optimize: use partition for 25th percentile (much faster than full sort) | |
| k = int(np.ceil(n_valid * 0.25)) - 1 | |
| k = max(0, min(k, n_valid - 1)) # Clamp to valid range | |
| averaged = np.partition(stack_clean, k, axis=-1)[..., k] | |
| elif fuse_method == 'pct75': | |
| # Optimize: use partition for 75th percentile | |
| k = int(np.ceil(n_valid * 0.75)) - 1 | |
| k = max(0, min(k, n_valid - 1)) # Clamp to valid range | |
| averaged = np.partition(stack_clean, k, axis=-1)[..., k] | |
| else: | |
| raise ValueError(f"Unknown fuse_method: {fuse_method}") | |
| # Restore NaN: all-NaN pixels become NaN again, and any +inf sentinel that leaked into the result (k landed on a NaN slot) is mapped back | |
| all_nan_mask = nan_mask.all(axis=-1) | |
| averaged[all_nan_mask] = np.nan | |
| averaged[averaged == np.inf] = np.nan | |
| # Move axis back if we moved it | |
| if moved_axis: | |
| averaged = np.moveaxis(averaged, -1, axis) | |
| return averaged | |
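| # Illustrative sketch of the nearest-rank trick above (comments only): | |
| # np.partition places the k-th smallest element at index k in O(n) rather than | |
| # a full O(n log n) sort, e.g.: | |
| #   a = np.array([4., 1., 3., 2.]) | |
| #   k = int(np.ceil(len(a) * 0.25)) - 1      # k = 0 for pct25 | |
| #   np.partition(a, k)[k] == np.sort(a)[k]   # True: both give 1.0 | |
| # Note this is nearest-rank, so it can differ slightly from np.nanpercentile, | |
| # which interpolates between ranks by default. | |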
| def _fuse_fundamentals(self, fundamentals_list, fuse_method): | |
| """Fuse per-shift fundamentals using specified method. | |
| Args: | |
| fundamentals_list: List[List[np.ndarray]] - [img_idx][shift_idx] = [1, C, H, W] | |
| fuse_method: 'mean', 'median', 'lowest', 'pct25', 'pct75' | |
| Returns: | |
| np.ndarray: [B, C, H, W] fused fundamentals | |
| """ | |
| import numpy as np | |
| fused = [] | |
| for img_fundamentals in fundamentals_list: # Each image | |
| # Filter out None values and stack: [num_valid_shifts, C, H, W] | |
| valid_shifts = [f.squeeze(0) for f in img_fundamentals if f is not None] | |
| if not valid_shifts: | |
| raise ValueError("No valid shifts found for image") | |
| stack = np.stack(valid_shifts, axis=0) # [num_shifts, C, H, W] | |
| # Fuse across shift dimension (axis=0) using optimized partition method | |
| fused_img = self._fuse_array(stack, fuse_method, axis=0) | |
| fused.append(fused_img) | |
| return np.stack(fused, axis=0) # [B, C, H, W] | |
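| # Illustrative shapes (assumed values, matching the docstring above): for B=2 | |
| # images and 3 shifts of 4-channel 256x256 latents, | |
| #   fundamentals_list[0] == [s0, s1, s2], each of shape (1, 4, 256, 256), | |
| # and the returned array has shape (2, 4, 256, 256) after fusing over shifts. | |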
| def _compute_tta_batched(self, fundamentals_dict, B, num_shifts, shifts, img_size, | |
| aug_size, offset, shift_method, args): | |
| """Phase 2: GPU-batched TTA processing (5-8× faster than sequential). | |
| Batches all (B × num_shifts) shift operations into a single GPU call. | |
| Returns: | |
| tuple: (anomaly_maps_batch, individual_shifts) | |
| """ | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| keys = ['anomaly_geometric', 'anomaly_arithmetic', 'latent_discrepancy', 'image_discrepancy'] | |
| fuse_method = getattr(args, 'fuse_method', 'mean') | |
| # Collect all valid fundamentals | |
| batch_fundamentals = [] # List of (img_idx, shift_idx, position, fundamentals) | |
| for shift_idx, (dx, dy) in enumerate(shifts): | |
| # Calculate position on canvas | |
| y0 = offset + dy | |
| x0_pos = offset + dx | |
| y1 = y0 + img_size | |
| x1 = x0_pos + img_size | |
| # Skip invalid crops | |
| if not (0 <= y0 < y1 <= aug_size and 0 <= x0_pos < x1 <= aug_size): | |
| continue | |
| for img_idx in range(B): | |
| # Get fundamentals from fundamentals_dict | |
| x0_fund = fundamentals_dict['x0'][img_idx][shift_idx] | |
| encoded_fund = fundamentals_dict['encoded'][img_idx][shift_idx] | |
| image_samples_fund = fundamentals_dict['image_samples'][img_idx][shift_idx] | |
| latent_samples_fund = fundamentals_dict['latent_samples'][img_idx][shift_idx] | |
| # Skip if fundamentals are None | |
| if x0_fund is None or encoded_fund is None: | |
| continue | |
| batch_fundamentals.append({ | |
| 'img_idx': img_idx, | |
| 'shift_idx': shift_idx, | |
| 'position': (x0_pos, y0, x1, y1), | |
| 'shift': (dx, dy), | |
| 'x0': x0_fund, | |
| 'encoded': encoded_fund, | |
| 'image_samples': image_samples_fund, | |
| 'latent_samples': latent_samples_fund | |
| }) | |
| if not batch_fundamentals: | |
| # No valid fundamentals - return empty results | |
| anomaly_maps_batch = {k: np.zeros((B, img_size, img_size), dtype=np.float32) for k in keys} | |
| individual_shifts = [[] for _ in range(B)] | |
| return anomaly_maps_batch, individual_shifts | |
| # Stack all fundamentals into batched tensors [N, C, H, W] | |
| x0_batch = torch.from_numpy(np.concatenate([f['x0'] for f in batch_fundamentals], axis=0)).to(self.device) | |
| encoded_batch = torch.from_numpy(np.concatenate([f['encoded'] for f in batch_fundamentals], axis=0)).to(self.device) | |
| image_samples_batch = torch.from_numpy(np.concatenate([f['image_samples'] for f in batch_fundamentals], axis=0)).to(self.device) | |
| latent_samples_batch = torch.from_numpy(np.concatenate([f['latent_samples'] for f in batch_fundamentals], axis=0)).to(self.device) | |
| # Check if GPU ops are enabled (Phase 3) | |
| gpu_ops_enabled = getattr(args, 'gpu_ops', True) | |
| if gpu_ops_enabled: | |
| # Phase 3: GPU-accelerated path (1.5-2× faster) | |
| anomaly_batch = self.calculate_anomaly_maps_batch_gpu( | |
| x0_batch, encoded_batch, image_samples_batch, latent_samples_batch, | |
| crop_size=img_size, | |
| image_diff_clip_max=self.image_diff_clip_max, | |
| latent_diff_clip_max=self.latent_diff_clip_max | |
| ) | |
| # Keep canvas on GPU as torch tensors | |
| all_maps_gpu = {k: torch.full((B, num_shifts, aug_size, aug_size), float('nan'), | |
| device=self.device, dtype=torch.float32) for k in keys} | |
| individual_shifts = [[] for _ in range(B)] | |
| # Track which (img_idx, shift_idx) combinations are valid | |
| valid_shifts_tracker = [[False] * num_shifts for _ in range(B)] | |
| for idx, fund_info in enumerate(batch_fundamentals): | |
| img_idx = fund_info['img_idx'] | |
| shift_idx = fund_info['shift_idx'] | |
| x0_pos, y0, x1, y1 = fund_info['position'] | |
| dx, dy = fund_info['shift'] | |
| # Place on GPU canvas (in-place GPU operation) | |
| for key in keys: | |
| all_maps_gpu[key][img_idx, shift_idx, y0:y1, x0_pos:x1] = anomaly_batch[key][idx] | |
| # Mark this shift as valid | |
| valid_shifts_tracker[img_idx][shift_idx] = True | |
| # Store individual shift data | |
| # Keep plotting crops on CPU; no need to round-trip them through the GPU | |
| crop_original = torch.from_numpy(fund_info['x0']) | |
| crop_reconstruction = torch.from_numpy(fund_info['image_samples']) | |
| individual_shifts[img_idx].append({ | |
| 'shift': (dx, dy), | |
| 'position': (x0_pos, y0, x1, y1), | |
| 'arithmetic_combined': anomaly_batch['anomaly_arithmetic'][idx].cpu().numpy(), | |
| 'crop_original': crop_original.detach().cpu(), | |
| 'crop_reconstruction': crop_reconstruction.detach().cpu(), | |
| 'shift_method': shift_method | |
| }) | |
| # Fuse shifts on GPU | |
| all_averaged_maps_gpu = {} | |
| for key in keys: | |
| fused_per_image = [] | |
| for img_idx in range(B): | |
| # Get valid shift indices for this image | |
| shift_maps = all_maps_gpu[key][img_idx] # [num_shifts, H, W] | |
| valid_indices = [i for i, is_valid in enumerate(valid_shifts_tracker[img_idx]) if is_valid] | |
| if not valid_indices: | |
| fused_per_image.append(torch.zeros((aug_size, aug_size), device=self.device)) | |
| continue | |
| valid_shifts = shift_maps[valid_indices] # [num_valid, H, W] | |
| # Fuse on GPU using torch operations | |
| if fuse_method == 'mean': | |
| fused = torch.nanmean(valid_shifts, dim=0) | |
| elif fuse_method == 'median': | |
| fused = torch.nanmedian(valid_shifts, dim=0).values | |
| elif fuse_method == 'pct25': | |
| fused = torch.nanquantile(valid_shifts, 0.25, dim=0) | |
| elif fuse_method == 'pct75': | |
| fused = torch.nanquantile(valid_shifts, 0.75, dim=0) | |
| elif fuse_method == 'lowest': | |
| # q=0 quantile is the minimum over non-NaN values (NaN if all are NaN) | |
| fused = torch.nanquantile(valid_shifts, 0.0, dim=0) | |
| else: | |
| fused = torch.nanmean(valid_shifts, dim=0)  # fallback for unknown methods | |
| fused_per_image.append(fused) | |
| all_averaged_maps_gpu[key] = torch.stack(fused_per_image) # [B, H, W] | |
| # Resize on GPU and convert to numpy only at the end | |
| anomaly_maps_batch = {} | |
| if shift_method == 'interpolate': | |
| for key in keys: | |
| resized = F.interpolate( | |
| all_averaged_maps_gpu[key].unsqueeze(1), # [B, 1, H, W] | |
| size=(img_size, img_size), | |
| mode='bilinear', | |
| align_corners=False | |
| ).squeeze(1) # [B, H, W] | |
| anomaly_maps_batch[key] = resized.cpu().numpy() | |
| else: | |
| for key in keys: | |
| cropped = all_averaged_maps_gpu[key][:, offset:offset+img_size, offset:offset+img_size] | |
| anomaly_maps_batch[key] = cropped.cpu().numpy() | |
| else: | |
| # Legacy CPU path | |
| anomaly_batch = self.calculate_anomaly_maps_batch( | |
| x0_batch, encoded_batch, image_samples_batch, latent_samples_batch, | |
| crop_size=img_size, | |
| image_diff_clip_max=self.image_diff_clip_max, | |
| latent_diff_clip_max=self.latent_diff_clip_max | |
| ) | |
| # Place results on CPU canvas | |
| all_maps = {k: [[None for _ in range(num_shifts)] for _ in range(B)] for k in keys} | |
| individual_shifts = [[] for _ in range(B)] | |
| for idx, fund_info in enumerate(batch_fundamentals): | |
| img_idx = fund_info['img_idx'] | |
| shift_idx = fund_info['shift_idx'] | |
| x0_pos, y0, x1, y1 = fund_info['position'] | |
| dx, dy = fund_info['shift'] | |
| # Place on canvas | |
| for key in keys: | |
| canvas = np.full((aug_size, aug_size), np.nan, dtype=np.float32) | |
| canvas[y0:y1, x0_pos:x1] = anomaly_batch[key][idx] | |
| all_maps[key][img_idx][shift_idx] = canvas | |
| # Store individual shift data | |
| # Keep plotting crops on CPU; no need to round-trip them through the GPU | |
| crop_original = torch.from_numpy(fund_info['x0']) | |
| crop_reconstruction = torch.from_numpy(fund_info['image_samples']) | |
| individual_shifts[img_idx].append({ | |
| 'shift': (dx, dy), | |
| 'position': (x0_pos, y0, x1, y1), | |
| 'arithmetic_combined': anomaly_batch['anomaly_arithmetic'][idx].copy(), | |
| 'crop_original': crop_original.detach().cpu(), | |
| 'crop_reconstruction': crop_reconstruction.detach().cpu(), | |
| 'shift_method': shift_method | |
| }) | |
| # Fuse shifts | |
| all_averaged_maps = {} | |
| for key in keys: | |
| all_averaged_maps[key] = [] | |
| for img_idx in range(B): | |
| valid_maps = [m for m in all_maps[key][img_idx] if m is not None] | |
| if not valid_maps: | |
| all_averaged_maps[key].append(np.zeros((aug_size, aug_size), dtype=np.float32)) | |
| continue | |
| stack = np.stack(valid_maps, axis=-1) | |
| averaged = self._fuse_array(stack, fuse_method, axis=-1) | |
| all_averaged_maps[key].append(averaged) | |
| # Resize to original size | |
| anomaly_maps_batch = {}  # FIX: previously only initialized in the GPU branch, causing a NameError on this CPU path | |
| if shift_method == 'interpolate': | |
| for key in keys: | |
| stacked = np.stack(all_averaged_maps[key], axis=0) | |
| tensor = torch.from_numpy(stacked).unsqueeze(1).to(self.device) | |
| resized = F.interpolate(tensor, size=(img_size, img_size), mode='bilinear', align_corners=False) | |
| anomaly_maps_batch[key] = resized.squeeze(1).cpu().numpy() | |
| else: | |
| for key in keys: | |
| stacked = np.stack(all_averaged_maps[key], axis=0) | |
| anomaly_maps_batch[key] = stacked[:, offset:offset+img_size, offset:offset+img_size] | |
| return anomaly_maps_batch, individual_shifts | |
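| # Sketch of the batching idea above (hypothetical shapes): instead of | |
| # N = B * num_shifts separate calls on (1, C, H, W) tensors, all crops are | |
| # concatenated into one (N, C, H, W) tensor so calculate_anomaly_maps_batch_gpu | |
| # runs once; results are then scattered back onto per-(image, shift) canvases. | |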
| def _compute_tta_sequential(self, fundamentals_dict, B, num_shifts, shifts, img_size, | |
| aug_size, offset, shift_method, args): | |
| """Legacy sequential TTA processing (kept for backward compatibility).""" | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| keys = ['anomaly_geometric', 'anomaly_arithmetic', 'latent_discrepancy', 'image_discrepancy'] | |
| all_maps = {k: [[None for _ in range(num_shifts)] for _ in range(B)] for k in keys} | |
| individual_shifts = [[] for _ in range(B)] | |
| # Process each shift sequentially (one GPU call per image-shift pair, e.g. 81 calls for a 9x9 shift grid) | |
| for shift_idx, (dx, dy) in enumerate(shifts): | |
| y0 = offset + dy | |
| x0_pos = offset + dx | |
| y1 = y0 + img_size | |
| x1 = x0_pos + img_size | |
| if not (0 <= y0 < y1 <= aug_size and 0 <= x0_pos < x1 <= aug_size): | |
| continue | |
| for img_idx in range(B): | |
| x0_fund = fundamentals_dict['x0'][img_idx][shift_idx] | |
| encoded_fund = fundamentals_dict['encoded'][img_idx][shift_idx] | |
| image_samples_fund = fundamentals_dict['image_samples'][img_idx][shift_idx] | |
| latent_samples_fund = fundamentals_dict['latent_samples'][img_idx][shift_idx] | |
| if x0_fund is None or encoded_fund is None: | |
| continue | |
| x0_tensor = torch.from_numpy(x0_fund).to(self.device) | |
| encoded_tensor = torch.from_numpy(encoded_fund).to(self.device) | |
| image_samples_tensor = torch.from_numpy(image_samples_fund).to(self.device) | |
| latent_samples_tensor = torch.from_numpy(latent_samples_fund).to(self.device) | |
| anomaly_batch = self.calculate_anomaly_maps_batch( | |
| x0_tensor, encoded_tensor, image_samples_tensor, latent_samples_tensor, | |
| crop_size=img_size, | |
| image_diff_clip_max=self.image_diff_clip_max, | |
| latent_diff_clip_max=self.latent_diff_clip_max | |
| ) | |
| for key in keys: | |
| canvas = np.full((aug_size, aug_size), np.nan, dtype=np.float32) | |
| canvas[y0:y1, x0_pos:x1] = anomaly_batch[key][0] | |
| all_maps[key][img_idx][shift_idx] = canvas | |
| crop_original = x0_tensor.detach().cpu() | |
| crop_reconstruction = image_samples_tensor.detach().cpu() | |
| individual_shifts[img_idx].append({ | |
| 'shift': (dx, dy), | |
| 'position': (x0_pos, y0, x1, y1), | |
| 'arithmetic_combined': anomaly_batch['anomaly_arithmetic'][0].copy(), | |
| 'crop_original': crop_original, | |
| 'crop_reconstruction': crop_reconstruction, | |
| 'shift_method': shift_method | |
| }) | |
| # Fuse shifts | |
| fuse_method = getattr(args, 'fuse_method', 'mean') | |
| all_averaged_maps = {} | |
| for key in keys: | |
| all_averaged_maps[key] = [] | |
| for img_idx in range(B): | |
| valid_maps = [m for m in all_maps[key][img_idx] if m is not None] | |
| if not valid_maps: | |
| all_averaged_maps[key].append(np.zeros((aug_size, aug_size), dtype=np.float32)) | |
| continue | |
| stack = np.stack(valid_maps, axis=-1) | |
| averaged = self._fuse_array(stack, fuse_method, axis=-1) | |
| all_averaged_maps[key].append(averaged) | |
| # Resize | |
| anomaly_maps_batch = {} | |
| if shift_method == 'interpolate': | |
| for key in keys: | |
| stacked = np.stack(all_averaged_maps[key], axis=0) | |
| tensor = torch.from_numpy(stacked).unsqueeze(1).to(self.device) | |
| resized = F.interpolate(tensor, size=(img_size, img_size), mode='bilinear', align_corners=False) | |
| anomaly_maps_batch[key] = resized.squeeze(1).cpu().numpy() | |
| else: | |
| for key in keys: | |
| stacked = np.stack(all_averaged_maps[key], axis=0) | |
| anomaly_maps_batch[key] = stacked[:, offset:offset+img_size, offset:offset+img_size] | |
| return anomaly_maps_batch, individual_shifts | |
| def _compute_anomaly_maps_iter(self, fundamental_iter, args, diffusion=None): | |
| """Compute anomaly maps from fundamentals. | |
| Consumes: | |
| - 8-tuple (non-TTA): (x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths) | |
| - 6-tuple (TTA with precomputed fundamentals): (x, mask, object_cls, fundamentals_dict, paths, tta_metadata) | |
| Yields: (anomaly_maps_dict, mask, paths, individual_shifts, orig_imgs) | |
| """ | |
| for batch_data in fundamental_iter: | |
| # Detect tuple format (6-tuple for TTA with fundamentals, 8-tuple for non-TTA or TTA without fundamentals) | |
| if len(batch_data) == 6: | |
| # TTA mode with precomputed fundamentals | |
| x, mask, object_cls, fundamentals_dict, paths, tta_metadata = batch_data | |
| has_precomputed_fundamentals = True | |
| else: | |
| # Non-TTA mode or TTA without precomputed fundamentals | |
| x, mask, object_cls, x0, encoded, image_samples, latent_samples, paths = batch_data | |
| has_precomputed_fundamentals = False | |
| tta_start = time.time() | |
| # Check for shutdown request | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ Anomaly map computation interrupted. Shutting down...") | |
| break | |
| individual_shifts = None | |
| orig_imgs = None | |
| # Check if TTA is enabled | |
| if args.pad_px is not None and has_precomputed_fundamentals: | |
| # TTA with precomputed fundamentals - compute anomaly maps from cached fundamentals | |
| orig_imgs = x # May be None when loading from cache | |
| # Get batch size from fundamentals_dict structure (since x may be None) | |
| B = len(fundamentals_dict['x0']) | |
| # Get image size from first fundamental (all should have same size) | |
| first_fund = None | |
| for img_idx in range(B): | |
| for shift_idx in range(tta_metadata['num_shifts']): | |
| if fundamentals_dict['x0'][img_idx][shift_idx] is not None: | |
| first_fund = fundamentals_dict['x0'][img_idx][shift_idx] | |
| break | |
| if first_fund is not None: | |
| break | |
| if first_fund is None: | |
| raise ValueError("No valid fundamentals found in batch") | |
| # Get image size from fundamental shape [1, C, H, W] | |
| # Fundamentals are saved with batch dimension, so shape[2] is height | |
| img_size = first_fund.shape[2] | |
| # Get metadata from tta_metadata | |
| num_shifts = tta_metadata['num_shifts'] | |
| shifts = tta_metadata['shifts'] | |
| # Determine shift_method and img_enlarge_px from args | |
| # Use pad_px from args (not tta_metadata which may be None) | |
| shift_method = getattr(args, 'shift_method', 'interpolate') | |
| img_enlarge_px = getattr(args, 'img_enlarge_px', 4) | |
| pad_px_val = args.pad_px | |
| # Calculate aug_size based on shift_method | |
| if shift_method == 'interpolate': | |
| aug_size = img_size + 2 * img_enlarge_px | |
| offset = img_enlarge_px | |
| elif shift_method == 'mirror': | |
| aug_size = img_size + 2 * pad_px_val | |
| offset = pad_px_val | |
| else: | |
| raise ValueError(f"Unknown shift_method: {shift_method}") | |
| # Check if batched TTA is enabled (Phase 2 optimization) | |
| batched_tta_enabled = getattr(args, 'batched_tta', True) | |
| if batched_tta_enabled: | |
| # Phase 2: GPU-batched TTA processing (5-8× faster) | |
| anomaly_maps_batch, individual_shifts = self._compute_tta_batched( | |
| fundamentals_dict, B, num_shifts, shifts, img_size, aug_size, | |
| offset, shift_method, args | |
| ) | |
| else: | |
| # Legacy: Sequential TTA processing | |
| anomaly_maps_batch, individual_shifts = self._compute_tta_sequential( | |
| fundamentals_dict, B, num_shifts, shifts, img_size, aug_size, | |
| offset, shift_method, args | |
| ) | |
| elif args.pad_px is not None: | |
| # Use TTA path with _anomaly_shift_avg_combined_batch | |
| # Prepare model kwargs | |
| if isinstance(object_cls, torch.Tensor): | |
| object_cls_tensor = object_cls.to(self.device) | |
| if len(object_cls_tensor.shape) == 1: | |
| object_cls_tensor = object_cls_tensor.unsqueeze(1) | |
| else: | |
| if isinstance(object_cls, (list, tuple)): | |
| object_cls_tensor = torch.tensor(object_cls, device=self.device, dtype=torch.long) | |
| else: | |
| object_cls_tensor = torch.tensor([object_cls] * x.shape[0], device=self.device, dtype=torch.long) | |
| if len(object_cls_tensor.shape) == 1: | |
| object_cls_tensor = object_cls_tensor.unsqueeze(1) | |
| if object_cls_tensor.dtype != torch.long: | |
| object_cls_tensor = object_cls_tensor.long() | |
| model_kwargs = {'context': object_cls_tensor, 'mask': None} | |
| # Store original images for plot_shift_analysis | |
| orig_imgs = x | |
| # Use TTA method | |
| tta_compute_start = time.time() | |
| anomaly_maps_batch, individual_shifts = self._anomaly_shift_avg_combined_batch( | |
| x, | |
| model_kwargs=model_kwargs, | |
| diffusion=diffusion, | |
| reverse_steps=args.reverse_steps, | |
| pad_px=args.pad_px, | |
| img_enlarge_px=getattr(args, 'img_enlarge_px', 4), | |
| stride=getattr(args, 'stride', 1), | |
| directions=getattr(args, 'directions', ('h', 'v', 'diag')), | |
| shift_method=getattr(args, 'shift_method', 'interpolate'), | |
| fuse_method=getattr(args, 'fuse_method', 'mean'), | |
| tta_batch_size=getattr(args, 'tta_batch_size', 8), | |
| eta=0.0, | |
| crop_size=args.crop_size, | |
| ) | |
| tta_compute_time = time.time() - tta_compute_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] TTA computation total: {tta_compute_time:.3f}s", level='debug') | |
| else: | |
| # Standard path: compute anomaly maps using existing method | |
| standard_start = time.time() | |
| anomaly_maps_batch = self.calculate_anomaly_maps_batch( | |
| x0, encoded, image_samples, latent_samples, args.crop_size, | |
| image_diff_clip_max=self.image_diff_clip_max, | |
| latent_diff_clip_max=self.latent_diff_clip_max | |
| ) | |
| standard_time = time.time() - standard_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] Standard anomaly calc: {standard_time:.3f}s", level='debug') | |
| tta_total_time = time.time() - tta_start | |
| if self.debug_enabled: | |
| self.log_message(f" [TIMING] _compute_anomaly_maps_iter total: {tta_total_time:.3f}s", level='debug') | |
| # Yield derived outputs | |
| yield (anomaly_maps_batch, mask, paths, individual_shifts, orig_imgs) | |
| def _aggregate_and_compute_metrics(self, anomaly_map_iter, args): | |
| """Aggregate anomaly maps and compute metrics. | |
| Consumes: (anomaly_maps_dict, mask, paths, individual_shifts, orig_imgs) | |
| Returns: (all_anomaly_maps, all_masks, all_labels) | |
| """ | |
| from collections import defaultdict | |
| from tqdm import tqdm | |
| import numpy as np | |
| # Initialize storage | |
| all_anomaly_maps = defaultdict(list) | |
| all_masks = [] | |
| all_labels = [] | |
| all_image_paths = [] | |
| batch_idx = 0 | |
| # Accumulate | |
| for (anomaly_maps_batch, mask, paths, individual_shifts, orig_imgs) in tqdm(anomaly_map_iter, desc="Evaluating"): | |
| # Check for shutdown request | |
| if self._shutdown_requested: | |
| print(f"\n⚠️ Evaluation interrupted at batch {batch_idx}. Shutting down...") | |
| break | |
| for key in anomaly_maps_batch.keys(): | |
| all_anomaly_maps[key].append(anomaly_maps_batch[key]) | |
| all_masks.append(mask) | |
| all_labels.append(mask)  # image-level labels: masks are kept as-is here and reduced to labels downstream | |
| all_image_paths.extend(paths) | |
| # Plot shift analysis if TTA is enabled (similar to decodiff_evaluator.py) | |
| if args.pad_px is not None and individual_shifts is not None and orig_imgs is not None and self.save_plot_path: | |
| try: | |
| # Reorganize individual_shifts from per-image structure to per-shift structure | |
| # Current: individual_shifts[img_idx] = list of shift dicts | |
| # Expected: individual_shifts = list of shift dicts, each with arithmetic_combined [B, H, W] | |
| if len(individual_shifts) > 0 and len(individual_shifts[0]) > 0: | |
| batch_size = len(individual_shifts) | |
| num_shifts = len(individual_shifts[0]) | |
| # Reorganize: group by shift instead of by image | |
| shifts_for_plot = [] | |
| for shift_idx in range(num_shifts): | |
| # Collect anomaly maps for all images for this shift | |
| arithmetic_maps = [] | |
| crop_originals = [] | |
| crop_reconstructions = [] | |
| # Get shift info from first image (all images have same shifts) | |
| first_shift_data = individual_shifts[0][shift_idx] | |
| shift = first_shift_data['shift'] | |
| position = first_shift_data['position'] | |
| shift_method = first_shift_data.get('shift_method', 'interpolate') | |
| # Collect data from all images for this shift | |
| for img_idx in range(batch_size): | |
| if shift_idx < len(individual_shifts[img_idx]): | |
| shift_data = individual_shifts[img_idx][shift_idx] | |
| # arithmetic_combined is [H, W] for this image+shift | |
| arithmetic_maps.append(shift_data['arithmetic_combined']) | |
| crop_originals.append(shift_data.get('crop_original')) | |
| crop_reconstructions.append(shift_data.get('crop_reconstruction')) | |
| # Stack to create [B, H, W] | |
| if arithmetic_maps: | |
| arithmetic_combined_batch = np.stack(arithmetic_maps, axis=0) # [B, H, W] | |
| shifts_for_plot.append({ | |
| 'shift': shift, | |
| 'position': position, | |
| 'arithmetic_combined': arithmetic_combined_batch, # [B, H, W] | |
| 'crop_original': crop_originals[0] if crop_originals else None, # Use first for display | |
| 'crop_reconstruction': crop_reconstructions[0] if crop_reconstructions else None, | |
| 'shift_method': shift_method | |
| }) | |
| # Get averaged anomaly map for entire batch | |
| averaged_anomaly_map = anomaly_maps_batch['anomaly_arithmetic'] | |
| # Get fuse_method from args for display in plot title | |
| fuse_method = getattr(args, 'fuse_method', 'mean') | |
| if not self._disable_file_saving: | |
| # Queue plot saving for background thread (non-blocking) | |
| if self.save_queue: | |
| # Convert tensors to numpy/cpu before queuing to avoid GPU memory issues | |
| orig_imgs_cpu = orig_imgs.detach().cpu().numpy() if isinstance(orig_imgs, torch.Tensor) else orig_imgs | |
| averaged_anomaly_map_cpu = averaged_anomaly_map.copy() if isinstance(averaged_anomaly_map, np.ndarray) else averaged_anomaly_map.detach().cpu().numpy() | |
| # Deep copy shifts_for_plot to avoid issues with shared references | |
| shifts_for_plot_cpu = copy.deepcopy(shifts_for_plot) | |
| # Convert any tensors in shift data to numpy | |
| for shift_data in shifts_for_plot_cpu: | |
| if 'crop_original' in shift_data and shift_data['crop_original'] is not None: | |
| if isinstance(shift_data['crop_original'], torch.Tensor): | |
| shift_data['crop_original'] = shift_data['crop_original'].detach().cpu().numpy() | |
| if 'crop_reconstruction' in shift_data and shift_data['crop_reconstruction'] is not None: | |
| if isinstance(shift_data['crop_reconstruction'], torch.Tensor): | |
| shift_data['crop_reconstruction'] = shift_data['crop_reconstruction'].detach().cpu().numpy() | |
| self.save_queue.put({ | |
| 'kind': 'plot', | |
| 'individual_shifts': shifts_for_plot_cpu, | |
| 'averaged_anomaly_map': averaged_anomaly_map_cpu, | |
| 'batch_idx': batch_idx, | |
| 'orig_imgs': orig_imgs_cpu, | |
| 'save_path': self.save_plot_path, | |
| 'max_shifts_to_show': 16, | |
| 'fuse_method': fuse_method | |
| }) | |
| else: | |
| # Fallback to synchronous if queue not available | |
| PR.plot_shift_analysis( | |
| individual_shifts=shifts_for_plot, | |
| averaged_anomaly_map=averaged_anomaly_map, | |
| batch_idx=batch_idx, | |
| orig_imgs=orig_imgs, | |
| save_path=self.save_plot_path, | |
| max_shifts_to_show=16, | |
| fuse_method=fuse_method | |
| ) | |
| else: | |
| print(f" [SKIP] Shift analysis plot (file saving disabled)") | |
| except Exception as e: | |
| print(f"Warning: Failed to plot shift analysis: {e}") | |
| import traceback | |
| traceback.print_exc() | |
| batch_idx += 1 | |
| # Concatenate | |
| for key in all_anomaly_maps.keys(): | |
| all_anomaly_maps[key] = np.concatenate(all_anomaly_maps[key], axis=0) | |
| all_masks = np.concatenate(all_masks, axis=0) | |
| all_labels = np.concatenate(all_labels, axis=0) | |
| # Calculate and print metrics | |
| for key in all_anomaly_maps.keys(): | |
| metrics = self.calculate_metrics(all_masks, all_anomaly_maps[key]) | |
| self.log_message(f"\n{key} metrics:") | |
| self.log_message(f" Pixel AUROC: {metrics[0]:.4f}") | |
| self.log_message(f" Pixel AUPRO: {metrics[1]:.4f}") | |
| self.log_message(f" Pixel F1: {metrics[2]:.4f}") | |
| self.log_message(f" Pixel AP: {metrics[3]:.4f}") | |
| self.log_message(f" Image AUROC: {metrics[4]:.4f}") | |
| self.log_message(f" Image AP: {metrics[5]:.4f}") | |
| self.log_message(f" Image F1: {metrics[6]:.4f}") | |
| # Save evaluation images if enabled | |
| if getattr(args, 'save_evaluation_images', False) and self.checkpoint_manager: | |
| category = getattr(args, 'object_category', 'unknown') | |
| annotation_dir = getattr(args, 'annotation_dir', None) | |
| npy_cache_dir = getattr(args, 'npy_cache_dir', None) | |
| patch_size = getattr(args, 'patch_size', 256) | |
| # Map CLI parameters to evaluation parameters | |
| # anomaly_threshold -> anomaly_binary_threshold (threshold for binary mask) | |
| # anomaly_min_area -> anomaly_pixel_num_threshold (minimum pixels for defect) | |
| anomaly_binary_threshold = getattr(args, 'anomaly_threshold', getattr(args, 'anomaly_binary_threshold', 5)) | |
| anomaly_pixel_num_threshold = getattr(args, 'anomaly_min_area', getattr(args, 'anomaly_pixel_num_threshold', 10)) | |
| overlay_alpha = getattr(args, 'image_overlay_alpha', 0.8) | |
| grid_thickness = getattr(args, 'grid_thickness', 1) | |
| self.save_evaluation_images_for_category( | |
| anomaly_maps=all_anomaly_maps, | |
| annotation_dir=annotation_dir, | |
| category=category, | |
| image_paths=all_image_paths, | |
| npy_cache_dir=npy_cache_dir, | |
| patch_size=patch_size, | |
| anomaly_binary_threshold=anomaly_binary_threshold, | |
| anomaly_pixel_num_threshold=anomaly_pixel_num_threshold, | |
| overlay_alpha=overlay_alpha, | |
| grid_thickness=grid_thickness | |
| ) | |
| # Create confusion matrix if enabled | |
| if getattr(args, 'enable_confusion_matrix', False): | |
| annotation_dir = getattr(args, 'annotation_dir', None) | |
| if not annotation_dir: | |
| self.log_message('Warning: Cannot generate confusion matrix without annotation_dir') | |
| else: | |
| self.log_message('\nGenerating confusion matrix...') | |
| try: | |
| import json | |
| patch_results_by_image = {} | |
| # If evaluation images were saved, load from JSONs | |
| if getattr(args, 'save_evaluation_images', False) and self.checkpoint_manager: | |
| evaluation_results_dir = self.checkpoint_manager.evaluation_results_dir | |
| if evaluation_results_dir.exists(): | |
| for json_file in evaluation_results_dir.glob('*__evaluation.json'): | |
| with open(json_file, 'r') as f: | |
| eval_data = json.load(f) | |
| image_path = eval_data.get('image_path', '') | |
| patch_analysis = eval_data.get('patch_analysis', []) | |
| if image_path and patch_analysis: | |
| patch_results_by_image[image_path] = patch_analysis | |
| # If no saved JSONs, compute patch results on the fly | |
| if not patch_results_by_image: | |
| # Re-derive patch parameters here: they are otherwise only bound inside the save_evaluation_images branch above (NameError fix when that flag is off) | |
| patch_size = getattr(args, 'patch_size', 256) | |
| anomaly_binary_threshold = getattr(args, 'anomaly_threshold', getattr(args, 'anomaly_binary_threshold', 5)) | |
| anomaly_pixel_num_threshold = getattr(args, 'anomaly_min_area', getattr(args, 'anomaly_pixel_num_threshold', 10)) | |
| predictions = all_anomaly_maps['anomaly_arithmetic'] | |
| ground_truth_map = self.load_ground_truth_map(annotation_dir) | |
| for idx, image_path in enumerate(all_image_paths): | |
| pred_mask = predictions[idx] | |
| if isinstance(pred_mask, torch.Tensor): | |
| pred_mask = pred_mask.detach().cpu().numpy() | |
| ground_truth_defective_patches = self.get_ground_truth_for_image( | |
| ground_truth_map, image_path | |
| ) | |
| patch_results = self.classify_patches( | |
| pred_mask, ground_truth_defective_patches, patch_size, | |
| anomaly_binary_threshold, anomaly_pixel_num_threshold | |
| ) | |
| patch_results_by_image[image_path] = patch_results | |
| if patch_results_by_image: | |
| # Determine output directory | |
| if self.save_plot_path: | |
| output_dir = str(self.save_plot_path) | |
| elif self.checkpoint_manager: | |
| output_dir = str(self.checkpoint_manager.results_dir) | |
| else: | |
| output_dir = './confusion_matrix_output' | |
| create_confusion_matrix_from_patch_results( | |
| patch_results_by_image=patch_results_by_image, | |
| output_dir=output_dir, | |
| patch_size=patch_size | |
| ) | |
| else: | |
| self.log_message('Warning: No evaluation results found for confusion matrix generation') | |
| except Exception as e: | |
| self.log_message(f'Warning: Failed to create confusion matrix: {e}') | |
| import traceback | |
| traceback.print_exc() | |
| # Save visualization variants if enabled | |
| save_image_variants = getattr(args, 'save_image_variants', None) | |
| if save_image_variants and self.save_plot_path: | |
| try: | |
| from dioodmi.utils.utils import path_to_safe_filename | |
| from dioodmi.utils.visualization_variants import create_anomaly_visualization_variants | |
| # Note: save_image_variants is guaranteed truthy by the guard above, so no fallback default is needed here | |
| save_colormaps = getattr(args, 'save_colormaps', None) or ['jet'] | |
| save_normalization = getattr(args, 'save_normalization', 'minmax') | |
| normal_center = getattr(args, 'normal_center', 125.0) | |
| save_binary_thresholds = getattr(args, 'save_binary_thresholds', None) or [5.0] | |
| # Handle 'all' variant option | |
| if 'all' in save_image_variants: | |
| save_image_variants = ['continuous', 'binary', 'grayscale', 'absolute'] | |
| # Create image_level directory for variants | |
| image_level_dir = Path(self.save_plot_path) / "marked_images" / "image_level" | |
| image_level_dir.mkdir(parents=True, exist_ok=True) | |
| # Save variants for each image | |
| for idx, image_path in enumerate(all_image_paths): | |
| safe_name = path_to_safe_filename(image_path) | |
| # Save variants for each map type | |
| for map_type in ['anomaly_arithmetic', 'anomaly_geometric', 'latent_discrepancy', 'image_discrepancy']: | |
| if map_type in all_anomaly_maps and idx < len(all_anomaly_maps[map_type]): | |
| map_data = all_anomaly_maps[map_type][idx] | |
| # Ensure it's a numpy array | |
| if isinstance(map_data, torch.Tensor): | |
| map_data = map_data.detach().cpu().numpy() | |
| # For grayscale variant, use signed difference if available | |
| signed_diff_data = None | |
| if 'grayscale' in save_image_variants: | |
| if map_type == 'image_discrepancy' and 'image_signed_diff' in all_anomaly_maps: | |
| if idx < len(all_anomaly_maps['image_signed_diff']): | |
| signed_diff_data = all_anomaly_maps['image_signed_diff'][idx] | |
| if isinstance(signed_diff_data, torch.Tensor): | |
| signed_diff_data = signed_diff_data.detach().cpu().numpy() | |
| elif map_type == 'latent_discrepancy' and 'latent_signed_diff' in all_anomaly_maps: | |
| if idx < len(all_anomaly_maps['latent_signed_diff']): | |
| signed_diff_data = all_anomaly_maps['latent_signed_diff'][idx] | |
| if isinstance(signed_diff_data, torch.Tensor): | |
| signed_diff_data = signed_diff_data.detach().cpu().numpy() | |
| # Create variants | |
| variants = create_anomaly_visualization_variants( | |
| map_data, | |
| variants=save_image_variants, | |
| colormaps=save_colormaps, | |
| normalization=save_normalization, | |
| normal_center=normal_center, | |
| thresholds=save_binary_thresholds, | |
| signed_diff_data=signed_diff_data # Pass signed difference for grayscale | |
| ) | |
| # Save each variant | |
| for variant_name, variant_img in variants.items(): | |
| variant_path = image_level_dir / f"{safe_name}__{map_type}__{variant_name}.png" | |
| PILImage.fromarray(variant_img).save(variant_path) | |
| except Exception as e: | |
| print(f"Warning: Failed to save visualization variants: {e}") | |
| # Save plots if requested | |
| if self.save_plot_path: | |
| self.log_message(f"Saving plots to {self.save_plot_path}") | |
| # TODO: Implement plot saving | |
| pass | |
| # Wait for all images to be saved before returning | |
| if self.save_plot_path and self.save_queue and not self._disable_file_saving: | |
| self.wait_for_saves() | |
| # Cleanup CUDA resources if shutdown was requested | |
| if self._shutdown_requested: | |
| try: | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| except Exception: | |
| pass # Ignore errors during cleanup | |
| print("Cleanup complete. Exiting...") | |
| return all_anomaly_maps, all_masks, all_labels | |
| def detect_and_evaluate_without_cache(self, args): | |
| """Mode 1: Fast evaluation without caching (original evaluate behavior). | |
| Composition: Inference → Process → Aggregate | |
| """ | |
| # Mode detection for full_image_eval | |
| if getattr(args, "full_image_eval", False): | |
| # Accept either flag name; the dispatcher in evaluate() checks 'enable_resume' | |
| return self._evaluate_full_image_tiled(args, getattr(args, 'resume', False) or getattr(args, 'enable_resume', False)) | |
| # Setup | |
| test_dataset, test_loader, diffusion = self._setup_evaluation(args) | |
| # Composition: Inference → Process → Aggregate | |
| inference_iter = self._run_inference_iter(args, test_dataset, test_loader, diffusion) | |
| anomaly_map_iter = self._compute_anomaly_maps_iter(inference_iter, args, diffusion) | |
| all_anomaly_maps, all_masks, all_labels = self._aggregate_and_compute_metrics(anomaly_map_iter, args) | |
| return all_anomaly_maps, all_masks, all_labels | |
| def detect_and_evaluate(self, args): | |
| """Mode 2: Standard workflow with detection result caching (DEFAULT). | |
| Composition: Inference → Save → Process → Aggregate | |
| """ | |
| import os | |
| # Setup | |
| test_dataset, test_loader, diffusion = self._setup_evaluation(args) | |
| save_dir = args.npy_cache_dir or os.path.join(args.save_plot_path, "npy_cache") | |
| os.makedirs(save_dir, exist_ok=True) | |
| # Composition: Inference → Save → Process → Aggregate | |
| inference_iter = self._run_inference_iter(args, test_dataset, test_loader, diffusion) | |
| saving_iter = self._save_fundamentals_iter(inference_iter, save_dir, args) | |
| anomaly_map_iter = self._compute_anomaly_maps_iter(saving_iter, args, diffusion) | |
| all_anomaly_maps, all_masks, all_labels = self._aggregate_and_compute_metrics(anomaly_map_iter, args) | |
| self.log_message(f"\n✓ Detection results cached to: {save_dir}") | |
| self.log_message(f"✓ Use --mode evaluate_only --npy_cache_dir {save_dir} to reprocess") | |
| return all_anomaly_maps, all_masks, all_labels | |
| def detect_only(self, args): | |
| """Mode 3: Detection only, save results for later evaluation. | |
| Composition: Inference → Save | |
| TTA fundamentals are now computed in _run_inference_iter and saved by _save_fundamentals_iter. | |
| """ | |
| import os | |
| # Setup | |
| test_dataset, test_loader, diffusion = self._setup_evaluation(args) | |
| save_dir = args.npy_cache_dir or os.path.join(args.save_plot_path, "npy_cache") | |
| os.makedirs(save_dir, exist_ok=True) | |
| # Composition: Inference → Save (TTA fundamentals computed and saved inline) | |
| inference_iter = self._run_inference_iter(args, test_dataset, test_loader, diffusion) | |
| saving_iter = self._save_fundamentals_iter(inference_iter, save_dir, args) | |
| # Consume iterator to trigger saving | |
| count = sum(1 for _ in saving_iter) | |
| self.log_message(f"\n✓ Detection completed: {count} batches processed") | |
| self.log_message(f"✓ Results saved to: {save_dir}") | |
| self.log_message(f"✓ Use --mode evaluate_only --npy_cache_dir {save_dir} to compute metrics") | |
| def evaluate_only(self, args): | |
| """Mode 4: Evaluate previously saved detection results. | |
| Composition: Load → Process → Aggregate | |
| """ | |
| import os | |
| # Validate | |
| save_dir = args.npy_cache_dir | |
| if not os.path.exists(save_dir): | |
| raise ValueError(f"NPY cache directory not found: {save_dir}") | |
| # Setup (need the dataset for ground truth masks; the same call also returns the diffusion object) | |
| test_dataset, test_loader, diffusion = self._setup_evaluation(args) | |
| # Composition: Load → Process → Aggregate | |
| loading_iter = self._load_fundamentals_iter_dispatcher(save_dir, test_dataset, args) | |
| anomaly_map_iter = self._compute_anomaly_maps_iter(loading_iter, args, diffusion) | |
| all_anomaly_maps, all_masks, all_labels = self._aggregate_and_compute_metrics(anomaly_map_iter, args) | |
| self.log_message(f"\n✓ Evaluation completed using cached results from: {save_dir}") | |
| self.log_message(f"✓ Processed {len(all_masks)} images") | |
| return all_anomaly_maps, all_masks, all_labels | |
| def evaluate(self, args): | |
| """ | |
| Main evaluation method with mode dispatch. | |
| This method processes the entire test dataset and generates anomaly detection results. | |
| Supports 4 modes for different workflows: | |
| - detect_and_evaluate_without_cache: Fast, no caching | |
| - detect_and_evaluate: Standard with NPY cache (DEFAULT) | |
| - detect_only: Save fundamentals for later | |
| - evaluate_only: Reprocess cached fundamentals | |
| Also supports full_image_eval mode for tiled evaluation of large images. | |
| """ | |
| # Wrap in try-except to catch KeyboardInterrupt | |
| try: | |
| # Get mode (default to detect_and_evaluate for backward compatibility) | |
| mode = getattr(args, 'mode', 'detect_and_evaluate') | |
| # Mode dispatch | |
| if mode == "detect_and_evaluate_without_cache": | |
| # Fast pipeline without caching | |
| if getattr(args, 'full_image_eval', False): | |
| return self._evaluate_full_image_tiled(args, getattr(args, 'enable_resume', False)) | |
| else: | |
| return self.detect_and_evaluate_without_cache(args) | |
| elif mode == "detect_and_evaluate": | |
| # Standard workflow with caching (DEFAULT) | |
| if getattr(args, 'full_image_eval', False): | |
| raise NotImplementedError("detect_and_evaluate mode not supported with full_image_eval yet") | |
| else: | |
| return self.detect_and_evaluate(args) | |
| elif mode == "detect_only": | |
| # Detection only, save for later evaluation | |
| if getattr(args, 'full_image_eval', False): | |
| raise NotImplementedError("detect_only mode not supported with full_image_eval yet") | |
| else: | |
| return self.detect_only(args) | |
| elif mode == "evaluate_only": | |
| # Evaluate previously saved detection results | |
| if getattr(args, 'full_image_eval', False): | |
| raise NotImplementedError("evaluate_only mode not supported with full_image_eval yet") | |
| else: | |
| return self.evaluate_only(args) | |
| else: | |
| raise ValueError( | |
| f"Unknown mode: {mode}. Must be one of: " | |
| "detect_and_evaluate_without_cache, detect_and_evaluate, detect_only, evaluate_only" | |
| ) | |
| except KeyboardInterrupt: | |
| # Catch KeyboardInterrupt explicitly (may be raised during blocking operations) | |
| # This is a fallback - the signal handler should have already set _shutdown_requested | |
| if not self._shutdown_requested: | |
| self.log_message("\n⚠️ KeyboardInterrupt caught. Initiating graceful shutdown...", level='warning') | |
| self._shutdown_requested = True | |
| finally: | |
| # Cleanup CUDA resources | |
| if self._shutdown_requested: | |
| try: | |
| if torch.cuda.is_available(): | |
| torch.cuda.empty_cache() | |
| torch.cuda.synchronize() | |
| except Exception: | |
| pass # Ignore errors during cleanup | |
| # Wait for pending saves before shutdown (prevents data loss) | |
| if self.save_plot_path and self.save_queue: | |
| queue_size = self.get_save_queue_size() | |
| if queue_size > 0: | |
| self.log_message(f"⏳ Waiting for {queue_size} pending save operations to complete...") | |
| self.wait_for_saves(timeout=30) # Shorter timeout for interrupts | |
| # Shutdown save worker if running | |
| if self.save_queue and self.save_worker_thread: | |
| self.shutdown_save_worker() | |
| self.log_message("Evaluation interrupted by user. Cleanup complete.") | |
| # Evaluation image saving methods (ported from decodiff_evaluator.py) | |
| def create_binary_mask(self, anomaly_map, threshold=5): | |
| """Create binary mask from anomaly map using threshold (0-255 range)""" | |
| # Auto-detect anomaly map range and normalize threshold accordingly | |
| amap_max = np.max(anomaly_map) | |
| if amap_max <= 1.0: | |
| # Anomaly map is in 0-1 range, normalize threshold from 0-255 to 0-1 | |
| normalized_threshold = threshold / 255.0 | |
| else: | |
| # Anomaly map is in 0-255 range, use threshold as-is | |
| normalized_threshold = threshold | |
| return (anomaly_map > normalized_threshold).astype(np.float32) | |
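| # Example (following the auto-detection above): threshold=5 on a 0-1 map | |
| # compares against 5/255 ≈ 0.0196, while on a 0-255 map it compares against 5 | |
| # directly, so callers can always pass thresholds in the 0-255 convention. | |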
| def count_anomaly_pixels(self, binary_mask): | |
| """Count number of positive pixels in binary mask""" | |
| return int(np.sum(binary_mask)) | |
| def load_single_annotation(self, annotation_path): | |
| """Load a single annotation file and return its data | |
| Args: | |
| annotation_path: Path to the annotation JSON file | |
| Returns: | |
| dict: Annotation data with keys 'image_path', 'defective_patches', 'grid_size' | |
| None: If file doesn't exist or can't be loaded | |
| """ | |
| import os | |
| import json | |
| if not os.path.exists(annotation_path): | |
| return None | |
| try: | |
| with open(annotation_path, 'r') as f: | |
| annotation = json.load(f) | |
| return annotation | |
| except Exception as e: | |
| print(f"Warning: Error reading annotation file {annotation_path}: {e}") | |
| return None | |
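| # Assumed annotation layout (a sketch inferred from the keys read here and in | |
| # load_ground_truth_map below; the exact schema comes from the annotation tool): | |
| #   {"image_path": "wood/test/hole/000.png", | |
| #    "defective_patches": [[0, 3], [1, 3]], ...} | |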
| def load_ground_truth_map(self, annotation_dir): | |
| """Load all annotations and build ground truth map | |
| Args: | |
| annotation_dir: Directory containing annotation JSON files | |
| Returns: | |
| dict: Mapping from image_path to set of (grid_row, grid_col) tuples | |
| """ | |
| import os | |
| from glob import glob | |
| ground_truth_map = {} | |
| if not annotation_dir or not os.path.exists(annotation_dir): | |
| return ground_truth_map | |
| # Find all annotation files | |
| annotation_pattern = os.path.join(annotation_dir, "*__annotations.json") | |
| annotation_files = glob(annotation_pattern) | |
| for annotation_file in annotation_files: | |
| annotation = self.load_single_annotation(annotation_file) | |
| if annotation is None: | |
| continue | |
| # Get the original image path from the annotation | |
| image_path = annotation.get("image_path") | |
| if image_path: | |
| # Convert defective patches to set of tuples | |
| defective_patches = set(tuple(patch) for patch in annotation.get("defective_patches", [])) | |
| ground_truth_map[image_path] = defective_patches | |
| return ground_truth_map | |
| def get_ground_truth_for_image(self, ground_truth_map, image_path): | |
| """Get defective patches for a specific image | |
| Args: | |
| ground_truth_map: Dictionary mapping image paths to defective patch sets | |
| image_path: Path to the image | |
| Returns: | |
| set: Set of (grid_row, grid_col) tuples for defective patches | |
| """ | |
| return ground_truth_map.get(image_path, set()) | |
| def generate_patch_coordinates(self, height, width, patch_size): | |
| """Generate 8-value coordinates for all patches in an image | |
| Args: | |
| height: Image height | |
| width: Image width | |
| patch_size: Size of each patch | |
| Returns: | |
| List of tuples: [(grid_row, grid_col, coords_8_values), ...] | |
| where coords_8_values = [x1, y1, x2, y2, x3, y3, x4, y4] | |
| representing the 4 corners of each patch | |
| """ | |
| n_rows = (height + patch_size - 1) // patch_size | |
| n_cols = (width + patch_size - 1) // patch_size | |
| patch_coordinates = [] | |
| for r in range(n_rows): | |
| for c in range(n_cols): | |
| # Calculate patch boundaries | |
| y1 = r * patch_size | |
| y2 = min(y1 + patch_size, height) | |
| x1 = c * patch_size | |
| x2 = min(x1 + patch_size, width) | |
| # Create 8-value coordinates for axis-aligned rectangle | |
| # Format: [x1, y1, x2, y2, x3, y3, x4, y4] (4 corners) | |
| coords_8_values = [ | |
| x1, y1, # Top-left | |
| x2, y1, # Top-right | |
| x2, y2, # Bottom-right | |
| x1, y2 # Bottom-left | |
| ] | |
| patch_coordinates.append((r, c, coords_8_values)) | |
| return patch_coordinates | |
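| # Example (illustrative): a 512x600 image with patch_size=256 yields a 2x3 grid; | |
| # the (r=1, c=2) patch is clamped at the right edge: | |
| #   coords_8_values = [512, 256, 600, 256, 600, 512, 512, 512] | |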
| def extract_patch(self, image, coords_8_values): | |
| """Extract a single patch from image using 8-value coordinates | |
| Args: | |
| image: Input image as numpy array | |
| coords_8_values: List of 8 values [x1, y1, x2, y2, x3, y3, x4, y4] | |
| Returns: | |
| Extracted patch as numpy array | |
| """ | |
| x1, y1, x2, y2, x3, y3, x4, y4 = coords_8_values | |
| # Check if patch is parallel to image axes (faster extraction) | |
| is_parallel = (y1 == y2 and y3 == y4 and x1 == x4 and x2 == x3) | |
| if is_parallel: | |
| # Fast path: rectangular patch aligned with image axes | |
| # Use top-left and bottom-right corners | |
| x_min, y_min = int(x1), int(y1) | |
| x_max, y_max = int(x3), int(y3) | |
| # Ensure coordinates are within image bounds | |
| height, width = image.shape[:2] | |
| x_min = max(0, min(x_min, width - 1)) | |
| y_min = max(0, min(y_min, height - 1)) | |
| x_max = max(x_min + 1, min(x_max, width)) | |
| y_max = max(y_min + 1, min(y_max, height)) | |
| # Extract patch using array slicing | |
| patch = image[y_min:y_max, x_min:x_max] | |
| return patch | |
| else: | |
| # TODO: Slow path for rotated patches - implement perspective transform | |
| # For now, raise an error to indicate this is not yet supported | |
| raise NotImplementedError("Rotated/perspective patches not yet supported") | |
| def classify_single_patch(self, pred_patch, grid_row, grid_col, ground_truth_defective_patches, | |
| anomaly_binary_threshold=5, anomaly_pixel_num_threshold=10): | |
| """Classify a single patch as TP/FP/FN/TN using annotation-based ground truth | |
| Args: | |
| pred_patch: Prediction patch as numpy array | |
| grid_row: Row position in grid | |
| grid_col: Column position in grid | |
| ground_truth_defective_patches: Set of (row, col) tuples representing defective patches | |
| anomaly_binary_threshold: Threshold for creating binary mask (0-255) | |
| anomaly_pixel_num_threshold: Minimum pixels needed for defect classification | |
| Returns: | |
| dict: Classification result with keys: | |
| - status: 'TP', 'FP', 'FN', or 'TN' | |
| - anomaly_pixels: Number of anomaly pixels found | |
| - pred_score: Mean prediction score | |
| - gt_present: Whether ground truth has defects | |
| """ | |
| # Create binary mask and count anomaly pixels | |
| binary_mask = self.create_binary_mask(pred_patch, anomaly_binary_threshold) | |
| anomaly_pixels = self.count_anomaly_pixels(binary_mask) | |
| # Determine if patch is predicted as defective using pixel counting | |
| pred_defective = anomaly_pixels > anomaly_pixel_num_threshold | |
| # Determine if ground truth has defects using annotations | |
| gt_defective = (grid_row, grid_col) in ground_truth_defective_patches | |
| # Classify patch | |
| if pred_defective and gt_defective: | |
| status = "TP" | |
| elif pred_defective and not gt_defective: | |
| status = "FP" | |
| elif not pred_defective and gt_defective: | |
| status = "FN" | |
| else: | |
| status = "TN" | |
| return { | |
| 'status': status, | |
| 'anomaly_pixels': anomaly_pixels, | |
| 'pred_score': float(np.mean(pred_patch)), | |
| 'gt_present': bool(gt_defective) | |
| } | |
| def classify_patches(self, pred_mask, ground_truth_defective_patches, patch_size, | |
| anomaly_binary_threshold=5, anomaly_pixel_num_threshold=10): | |
| """Classify patches for a single image into TP/FP/FN/TN categories using annotations | |
| This orchestrator function uses the separate patch coordinate generation, | |
| extraction, and classification functions to follow single responsibility principle. | |
| Args: | |
| pred_mask: Prediction mask as numpy array | |
| ground_truth_defective_patches: Set of (row, col) tuples representing defective patches | |
| patch_size: Size of patches to extract | |
| anomaly_binary_threshold: Threshold for creating binary mask (0-255) | |
| anomaly_pixel_num_threshold: Minimum pixels needed for defect classification | |
| """ | |
| h, w = pred_mask.shape | |
| # Generate coordinates for all patches using pure function | |
| patch_coordinates = self.generate_patch_coordinates(h, w, patch_size) | |
| patch_results = [] | |
| for grid_row, grid_col, coords_8_values in patch_coordinates: | |
| # Extract prediction patch using pure function | |
| pred_patch = self.extract_patch(pred_mask, coords_8_values) | |
| # Classify single patch using annotation-based ground truth | |
| classification_result = self.classify_single_patch( | |
| pred_patch, grid_row, grid_col, ground_truth_defective_patches, | |
| anomaly_binary_threshold, anomaly_pixel_num_threshold | |
| ) | |
| # Add grid position and coordinate information | |
| # Note: coords_8_values[0:2] gives top-left corner (x1, y1) | |
| x1, y1 = coords_8_values[0], coords_8_values[1] | |
| patch_result = { | |
| 'grid_row': int(grid_row), | |
| 'grid_col': int(grid_col), | |
| 'coords': (int(x1), int(y1)), | |
| **classification_result # Unpack status, anomaly_pixels, pred_score, gt_present | |
| } | |
| patch_results.append(patch_result) | |
| return patch_results | |
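| # Minimal usage sketch (values illustrative, assuming a non-overlapping grid): | |
| # a 1024x1024 pred_mask with patch_size=256 yields a 4x4 grid, so | |
| # classify_patches returns 16 dicts shaped like | |
| # {'grid_row': 0, 'grid_col': 0, 'coords': (0, 0), 'status': 'TN', | |
| # 'anomaly_pixels': 0, 'pred_score': 0.7, 'gt_present': False} | |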
| def save_evaluation_images_for_category(self, anomaly_maps: dict, annotation_dir: str, | |
| category: str, image_paths: list, | |
| npy_cache_dir: str = None, | |
| patch_size: int = 256, | |
| anomaly_binary_threshold: int = 5, | |
| anomaly_pixel_num_threshold: int = 10, | |
| overlay_alpha: float = 0.8, grid_thickness: int = 1): | |
| """Save evaluation images for a single category using existing confusion matrix pipeline""" | |
| self.log_message('\nSaving evaluation images...') | |
| # Use the arithmetic anomaly map as the primary prediction | |
| predictions = anomaly_maps['anomaly_arithmetic'] | |
| # Load ground truth annotations if available | |
| ground_truth_map = self.load_ground_truth_map(annotation_dir) if annotation_dir else {} | |
| # Load NPY manifest if npy_cache_dir is provided | |
| npy_manifest = {} | |
| if npy_cache_dir: | |
| manifest_path = os.path.join(npy_cache_dir, 'fundamentals_manifest.csv') | |
| if os.path.exists(manifest_path): | |
| import csv | |
| with open(manifest_path, 'r') as f: | |
| reader = csv.DictReader(f) | |
| for row in reader: | |
| # Map both original_file_path and full_file_path to npy_filename for flexible lookup | |
| # Handles both simple names ("000") and full paths ("wood/test/hole/000.png") | |
| original_path = row['original_file_path'] | |
| full_path = row.get('full_file_path', '') | |
| npy_filename = row['npy_filename'] | |
| npy_full_path = os.path.join(npy_cache_dir, npy_filename) | |
| # Add multiple lookup keys for robust path matching | |
| npy_manifest[original_path] = npy_full_path | |
| if full_path: | |
| npy_manifest[full_path] = npy_full_path | |
| # Also add basename without extension | |
| full_basename = os.path.splitext(os.path.basename(full_path))[0] | |
| npy_manifest[full_basename] = npy_full_path | |
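| # Example manifest row (paths hypothetical): original_file_path='000', | |
| # full_file_path='wood/test/hole/000.png', npy_filename='000.npz' registers | |
| # the lookup keys '000' and 'wood/test/hole/000.png', both resolving to | |
| # os.path.join(npy_cache_dir, '000.npz') | |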
| self.log_message(f"Processing {len(predictions)} images for {category}...") | |
| # Process each image | |
| for idx, (pred_mask, image_path) in enumerate(zip(predictions, image_paths)): | |
| # Ensure prediction mask is numpy array | |
| if isinstance(pred_mask, torch.Tensor): | |
| pred_mask = pred_mask.detach().cpu().numpy() | |
| # Get ground truth defective patches for this image | |
| ground_truth_defective_patches = self.get_ground_truth_for_image( | |
| ground_truth_map, image_path | |
| ) | |
| # Classify patches for this image using annotations | |
| patch_results = self.classify_patches( | |
| pred_mask, ground_truth_defective_patches, patch_size, | |
| anomaly_binary_threshold, anomaly_pixel_num_threshold | |
| ) | |
| # Convert patch results to sets for visualization | |
| predicted_defective_set = set() | |
| overlapping_set = set() | |
| for result in patch_results: | |
| grid_row = result['grid_row'] | |
| grid_col = result['grid_col'] | |
| status = result['status'] | |
| if status in ['TP', 'FP']: # Predicted as defective | |
| predicted_defective_set.add((grid_row, grid_col)) | |
| if status == 'TP': # True positive = overlapping | |
| overlapping_set.add((grid_row, grid_col)) | |
| # Create combined anomaly maps for saving | |
| combined_anomaly_maps = { | |
| 'arithmetic': anomaly_maps['anomaly_arithmetic'][idx] if 'anomaly_arithmetic' in anomaly_maps else pred_mask, | |
| 'geometric': anomaly_maps['anomaly_geometric'][idx] if 'anomaly_geometric' in anomaly_maps else pred_mask, | |
| 'latent': anomaly_maps['latent_discrepancy'][idx] if 'latent_discrepancy' in anomaly_maps else pred_mask, | |
| 'image': anomaly_maps['image_discrepancy'][idx] if 'image_discrepancy' in anomaly_maps else pred_mask | |
| } | |
| # Save evaluation images using the real image path | |
| try: | |
| self.save_evaluation_images( | |
| image_path=image_path, | |
| anomaly_maps=combined_anomaly_maps, | |
| patch_results=patch_results, | |
| predicted_defective_set=predicted_defective_set, | |
| ground_truth_defective_set=ground_truth_defective_patches, | |
| overlapping_set=overlapping_set, | |
| npy_manifest=npy_manifest, | |
| npy_cache_dir=npy_cache_dir, | |
| patch_size=patch_size, | |
| anomaly_binary_threshold=anomaly_binary_threshold, | |
| overlay_alpha=overlay_alpha, | |
| grid_thickness=grid_thickness, | |
| save_raw_anomaly_maps=getattr(self, 'save_raw_anomaly_maps', False), | |
| save_image_variants=getattr(self, 'save_image_variants', None), | |
| save_colormaps=getattr(self, 'save_colormaps', None), | |
| save_normalization=getattr(self, 'save_normalization', 'minmax'), | |
| normal_center=getattr(self, 'normal_center', 125.0), | |
| save_binary_thresholds=getattr(self, 'save_binary_thresholds', None) | |
| ) | |
| except Exception as e: | |
| print(f"Warning: Could not save images for {image_path}: {e}") | |
| self.log_message(f"Evaluation images saved for category: {category}") | |
| def save_evaluation_images(self, image_path: str, anomaly_maps: dict, patch_results: list, | |
| predicted_defective_set: set, ground_truth_defective_set: set, | |
| overlapping_set: set, npy_manifest: dict = None, | |
| npy_cache_dir: str = None, | |
| patch_size: int = 256, | |
| anomaly_binary_threshold: int = 5, | |
| overlay_alpha: float = 0.8, grid_thickness: int = 1, | |
| save_raw_anomaly_maps: bool = False, | |
| save_image_variants: Optional[List[str]] = None, | |
| save_colormaps: Optional[List[str]] = None, | |
| save_normalization: str = 'minmax', | |
| normal_center: float = 125.0, | |
| save_binary_thresholds: Optional[List[float]] = None): | |
| """Save all evaluation images in organized subfolder structure""" | |
| import os | |
| from PIL import Image as PILImage | |
| if not self.checkpoint_manager: | |
| return | |
| # Determine image status from patch results | |
| image_status = determine_image_status(patch_results) | |
| safe_name = path_to_safe_filename(image_path) | |
| # Load original image from NPY cache or disk | |
| original_img = None | |
| # Try loading from NPY cache first if manifest is provided | |
| if npy_manifest: | |
| # Try multiple lookup strategies for robust path matching | |
| lookup_keys = [ | |
| image_path, # Full path as provided | |
| os.path.splitext(os.path.basename(image_path))[0], # Basename without extension | |
| os.path.basename(image_path) # Basename with extension | |
| ] | |
| npy_path = None | |
| for key in lookup_keys: | |
| if key in npy_manifest: | |
| npy_path = npy_manifest[key] | |
| break | |
| if npy_path: | |
| try: | |
| # Handle both .npy and .npz files | |
| if npy_path.endswith('.npz'): | |
| # Load NPZ archive and extract 'original_image' key | |
| with np.load(npy_path) as data: | |
| if 'original_image' in data: | |
| original_tensor = data['original_image'] # Shape: (C, H, W), dtype: float32 | |
| else: | |
| raise KeyError(f"'original_image' key not found in {npy_path}") | |
| else: | |
| # Load regular NPY file | |
| original_tensor = np.load(npy_path) # Shape: (C, H, W), dtype: float32 | |
| # Convert from (C, H, W) to (H, W, C) and rescale to [0, 255] | |
| original_img = np.transpose(original_tensor, (1, 2, 0)) # (H, W, C) | |
| # Rescale from [-1, 1] or [0, 1] to [0, 255] with proper rounding | |
| if original_img.min() < 0: # Likely in [-1, 1] range | |
| original_img = np.clip(np.round((original_img + 1.0) * 127.5), 0, 255).astype(np.uint8) | |
| else: # Likely in [0, 1] range | |
| original_img = np.clip(np.round(original_img * 255.0), 0, 255).astype(np.uint8) | |
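| # Worked example of the rescale: a [-1, 1] value of 0.0 maps to | |
| # np.round((0.0 + 1.0) * 127.5) = 128; a [0, 1] value of 0.5 maps to | |
| # np.round(0.5 * 255.0) = 128 as well | |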
| except Exception as e: | |
| print(f"Warning: Could not load NPY/NPZ image from {npy_path}: {e}") | |
| # Fallback to loading from disk if NPY failed or not available | |
| if original_img is None: | |
| try: | |
| if os.path.exists(image_path): | |
| original_img = np.array(PILImage.open(image_path).convert('RGB')) | |
| except Exception as e: | |
| print(f"Warning: Could not load image from disk {image_path}: {e}") | |
| # Final fallback: Search NPZ files in cache directory to resurrect original image | |
| if original_img is None and npy_cache_dir and os.path.exists(npy_cache_dir): | |
| try: | |
| import glob | |
| # Try to find any NPZ file in the cache directory | |
| npz_files = glob.glob(os.path.join(npy_cache_dir, '*.npz')) | |
| # Try to match by image index or basename | |
| image_basename = os.path.splitext(os.path.basename(image_path))[0] | |
| for npz_file in npz_files: | |
| try: | |
| with np.load(npz_file) as data: | |
| if 'original_image' in data: | |
| # Check if this might be the right file by examining the filename | |
| npz_basename = os.path.splitext(os.path.basename(npz_file))[0] | |
| # Match if the image basename appears in the NPZ filename | |
| if image_basename in npz_basename or npz_basename.startswith(image_basename): | |
| original_tensor = data['original_image'] | |
| # Convert from (C, H, W) to (H, W, C) and rescale to [0, 255] | |
| original_img = np.transpose(original_tensor, (1, 2, 0)) | |
| # Rescale from [-1, 1] or [0, 1] to [0, 255] with proper rounding | |
| if original_img.min() < 0: | |
| original_img = np.clip(np.round((original_img + 1.0) * 127.5), 0, 255).astype(np.uint8) | |
| else: | |
| original_img = np.clip(np.round(original_img * 255.0), 0, 255).astype(np.uint8) | |
| print(f"✓ Resurrected original image from cache: {npz_file}") | |
| break | |
| except Exception: | |
| continue # Try next NPZ file | |
| if original_img is None: | |
| raise RuntimeError(f"Could not find original image in NPZ cache for {image_path}") | |
| except Exception as e: | |
| raise RuntimeError(f"Failed to resurrect original image from cache for {image_path}: {e}") | |
| # Save original image (without any overlays or markings) | |
| self._save_to_subfolders(original_img, safe_name, image_status, 'original', 'original.png') | |
| # Save marked image (with patch rectangles) | |
| marked_img = draw_patch_rectangles_on_image( | |
| original_img, predicted_defective_set, | |
| ground_truth_defective_set, overlapping_set, | |
| patch_size, grid_thickness | |
| ) | |
| self._save_to_subfolders(marked_img, safe_name, image_status, 'marked', 'marked.png') | |
| # Set defaults for variant saving | |
| if save_image_variants is None: | |
| save_image_variants = ['continuous', 'binary'] | |
| if save_colormaps is None: | |
| save_colormaps = ['jet'] # Default for eval-evaluator | |
| if save_binary_thresholds is None: | |
| save_binary_thresholds = [anomaly_binary_threshold] | |
| # Handle 'all' variant option | |
| if 'all' in save_image_variants: | |
| save_image_variants = ['continuous', 'binary', 'grayscale', 'absolute'] | |
| # Check if we should use new variant system | |
| use_variant_system = (len(save_image_variants) > 2 or | |
| len(save_colormaps) > 1 or | |
| len(save_binary_thresholds) > 1 or | |
| 'grayscale' in save_image_variants or | |
| 'absolute' in save_image_variants) | |
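| # e.g. the defaults (['continuous', 'binary'], ['jet'], one threshold) keep | |
| # the legacy path below; asking for 'grayscale'/'absolute', a second | |
| # colormap, or extra thresholds switches to the variant system | |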
| # Save each type of anomaly map | |
| for map_type in ['arithmetic', 'geometric', 'latent', 'image']: | |
| if map_type not in anomaly_maps: | |
| continue | |
| map_data = anomaly_maps[map_type] | |
| # Ensure it's a numpy array | |
| if isinstance(map_data, torch.Tensor): | |
| map_data = map_data.detach().cpu().numpy() | |
| if use_variant_system: | |
| # Use new variant system | |
| from dioodmi.utils.visualization_variants import create_anomaly_visualization_variants | |
| # Create variants | |
| variants = create_anomaly_visualization_variants( | |
| map_data, | |
| variants=save_image_variants, | |
| colormaps=save_colormaps, | |
| normalization=save_normalization, | |
| normal_center=normal_center, | |
| thresholds=save_binary_thresholds | |
| ) | |
| # Save each variant | |
| for variant_name, variant_img in variants.items(): | |
| # Save to status folder | |
| self._save_to_subfolders( | |
| variant_img, safe_name, image_status, | |
| f'anomaly_maps/{map_type}', f'{map_type}__{variant_name}.png' | |
| ) | |
| else: | |
| # Use legacy system (backward compatible) | |
| # Save raw anomaly map | |
| anomaly_img = create_anomaly_visualization( | |
| map_data, is_binary=False | |
| ) | |
| self._save_to_subfolders( | |
| anomaly_img, safe_name, image_status, | |
| f'anomaly_maps/{map_type}', f'{map_type}.png' | |
| ) | |
| # Save binary version and overlays for arithmetic/geometric | |
| if map_type in ['arithmetic', 'geometric']: | |
| # Save binary version | |
| binary_img = create_anomaly_visualization( | |
| map_data, is_binary=True, threshold=anomaly_binary_threshold | |
| ) | |
| self._save_to_subfolders( | |
| binary_img, safe_name, image_status, | |
| f'binary/{map_type}', f'{map_type}_binary.png' | |
| ) | |
| # Save overlay | |
| overlay_img = create_anomaly_overlay( | |
| original_img, map_data, alpha=overlay_alpha, | |
| threshold=anomaly_binary_threshold | |
| ) | |
| self._save_to_subfolders( | |
| overlay_img, safe_name, image_status, | |
| f'overlays/{map_type}', f'{map_type}_overlay.png' | |
| ) | |
| # Save marked overlay | |
| marked_overlay = draw_patch_rectangles_on_image( | |
| overlay_img, predicted_defective_set, | |
| ground_truth_defective_set, overlapping_set, | |
| patch_size, grid_thickness | |
| ) | |
| self._save_to_subfolders( | |
| marked_overlay, safe_name, image_status, | |
| 'overlays/marked', f'{map_type}_marked_overlay.png' | |
| ) | |
| # Save raw anomaly maps as NPY files if enabled | |
| if save_raw_anomaly_maps and self.checkpoint_manager: | |
| try: | |
| npy_save_dir = Path(self.checkpoint_manager.results_dir) / "anomaly_maps_npy" | |
| npy_save_dir.mkdir(parents=True, exist_ok=True) | |
| # Save raw anomaly maps (before any visualization) | |
| for map_type in ['arithmetic', 'geometric', 'latent', 'image']: | |
| if map_type in anomaly_maps: | |
| # Ensure it's a numpy array | |
| map_data = anomaly_maps[map_type] | |
| if isinstance(map_data, torch.Tensor): | |
| map_data = map_data.detach().cpu().numpy() | |
| # Save raw array | |
| np.save(npy_save_dir / f"{safe_name}__{map_type}.npy", map_data) | |
| # Save binary version if applicable | |
| if map_type in ['arithmetic', 'geometric']: | |
| binary_map = (map_data > anomaly_binary_threshold).astype(np.float32) | |
| np.save(npy_save_dir / f"{safe_name}__{map_type}_binary.npy", binary_map) | |
| except Exception as e: | |
| print(f"Warning: Failed to save raw anomaly maps for {image_path}: {e}") | |
| # Create and save concatenated 4-panel images (original, overlay, anomaly_map, binary) | |
| try: | |
| for map_type in ['arithmetic', 'geometric']: | |
| if map_type not in anomaly_maps: | |
| continue | |
| map_data = anomaly_maps[map_type] | |
| if isinstance(map_data, torch.Tensor): | |
| map_data = map_data.detach().cpu().numpy() | |
| # Create the 4 panels | |
| # 1. Original image | |
| panel_original = original_img | |
| # 2. Overlay image (original + anomaly map overlay) | |
| panel_overlay = create_anomaly_overlay( | |
| original_img, map_data, alpha=overlay_alpha, | |
| threshold=anomaly_binary_threshold | |
| ) | |
| # 3. Anomaly map (continuous with colormap) | |
| panel_anomaly_map = create_anomaly_visualization( | |
| map_data, is_binary=False | |
| ) | |
| # 4. Binary image | |
| panel_binary = create_anomaly_visualization( | |
| map_data, is_binary=True, threshold=anomaly_binary_threshold | |
| ) | |
| # Ensure all panels have the same height | |
| h, w, _ = panel_original.shape | |
| panel_overlay = self._resize_if_needed(panel_overlay, h, w) | |
| panel_anomaly_map = self._resize_if_needed(panel_anomaly_map, h, w) | |
| panel_binary = self._resize_if_needed(panel_binary, h, w) | |
| # Concatenate horizontally: [overlay | original | binary | anomaly_map] | |
| concatenated_img = np.concatenate([ | |
| panel_overlay, panel_original, panel_binary, panel_anomaly_map | |
| ], axis=1) | |
| # Save concatenated image | |
| self._save_to_subfolders( | |
| concatenated_img, safe_name, image_status, | |
| 'concatenated', f'{map_type}_concatenated.png' | |
| ) | |
| except Exception as e: | |
| print(f"Warning: Failed to create concatenated images for {image_path}: {e}") | |
| # Save evaluation results JSON | |
| self._save_evaluation_json(image_path, patch_results, safe_name, patch_size) | |
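| # Resulting on-disk layout (reconstructed from the calls above; roots come | |
| # from checkpoint_manager): every image is written under both its status | |
| # folder (TP/FP/FN/TN) and image_level/, each with subfolders original/, | |
| # marked/, anomaly_maps/<type>/, binary/<type>/, overlays/<type>/, | |
| # overlays/marked/, concatenated/, plus optional anomaly_maps_npy/ and one | |
| # evaluation JSON per image in evaluation_results_dir | |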
| def _resize_if_needed(self, image: np.ndarray, target_h: int, target_w: int) -> np.ndarray: | |
| """Resize image if dimensions don't match target size""" | |
| from PIL import Image as PILImage | |
| h, w = image.shape[:2] | |
| if h != target_h or w != target_w: | |
| pil_img = PILImage.fromarray(image) | |
| pil_img = pil_img.resize((target_w, target_h), PILImage.LANCZOS) | |
| return np.array(pil_img) | |
| return image | |
| def _save_to_subfolders(self, image_array: np.ndarray, base_name: str, | |
| status: str, subfolder: str, suffix: str): | |
| """Helper to save image to both status and image_level folders""" | |
| from PIL import Image as PILImage | |
| # Save to status-specific subfolder (create directory on-demand) | |
| status_folder_path = self.checkpoint_manager.get_status_folder(status) / subfolder | |
| status_folder_path.mkdir(parents=True, exist_ok=True) | |
| status_path = status_folder_path / f"{base_name}_{suffix}" | |
| PILImage.fromarray(image_array).save(status_path) | |
| # Also save to image_level (without status classification) | |
| image_level_folder_path = self.checkpoint_manager.get_image_level_folder() / subfolder | |
| image_level_folder_path.mkdir(parents=True, exist_ok=True) | |
| image_level_path = image_level_folder_path / f"{base_name}_{suffix}" | |
| PILImage.fromarray(image_array).save(image_level_path) | |
| def _save_evaluation_json(self, image_path: str, patch_results: list, | |
| safe_name: str, patch_size: int): | |
| """Save evaluation results as JSON""" | |
| import json | |
| patch_analysis = [] | |
| for result in patch_results: | |
| patch_analysis.append({ | |
| "grid_row": result.get('grid_row', 0), | |
| "grid_col": result.get('grid_col', 0), | |
| "anomaly_pixels": result.get('anomaly_pixels', 0), | |
| "pred_score": result.get('pred_score', 0.0), | |
| "status": result.get('status', 'TN') | |
| }) | |
| # Ensure evaluation results directory exists | |
| self.checkpoint_manager.evaluation_results_dir.mkdir(parents=True, exist_ok=True) | |
| result_filename = f"{safe_name}__evaluation.json" | |
| result_path = self.checkpoint_manager.evaluation_results_dir / result_filename | |
| evaluation_result = { | |
| "image_path": image_path, | |
| "patch_analysis": patch_analysis, | |
| "grid_size": patch_size | |
| } | |
| with open(result_path, 'w') as f: | |
| json.dump(evaluation_result, f, indent=2) |
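| # Example of a saved <name>__evaluation.json (values illustrative): | |
| # { | |
| # "image_path": "wood/test/hole/000.png", | |
| # "patch_analysis": [ | |
| # {"grid_row": 0, "grid_col": 0, "anomaly_pixels": 0, | |
| # "pred_score": 0.42, "status": "TN"} | |
| # ], | |
| # "grid_size": 256 | |
| # } | |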
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """HAEDOSA Automation Engine CLI for DiOodMi. | |
| Unified command-line interface orchestrating DiOodMi operations. | |
| Follows HAEDOSA MS12 automation standards with export-based architecture. | |
| HAE provides: | |
| - Orchestration of core verbs (train, eval, reconstruct) | |
| - Workflow automation (benchmarking, multi-category training) | |
| - Example management | |
| - Command export (shows standalone equivalents) | |
| """ | |
| import os | |
| import subprocess | |
| import sys | |
| from pathlib import Path | |
| # Configure UTF-8 encoding for all I/O operations (Windows compatibility) | |
| # This must be set before any console output | |
| if sys.platform.startswith('win'): | |
| os.environ.setdefault('PYTHONIOENCODING', 'utf-8') | |
| # Set Windows console code page to UTF-8 | |
| try: | |
| import ctypes | |
| kernel32 = ctypes.windll.kernel32 | |
| kernel32.SetConsoleOutputCP(65001) # UTF-8 code page | |
| kernel32.SetConsoleCP(65001) | |
| except Exception: | |
| pass # Fallback if ctypes fails | |
| # Reconfigure stdout/stderr to UTF-8 | |
| try: | |
| sys.stdout.reconfigure(encoding='utf-8', errors='replace') | |
| sys.stderr.reconfigure(encoding='utf-8', errors='replace') | |
| except AttributeError: | |
| # Python < 3.7 fallback | |
| import io | |
| if not isinstance(sys.stdout, io.TextIOWrapper): | |
| sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace') | |
| if not isinstance(sys.stderr, io.TextIOWrapper): | |
| sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace') | |
| import click | |
| from click.core import Command | |
| from rich.console import Console | |
| from rich.table import Table | |
| from rich.progress import Progress, SpinnerColumn, TextColumn | |
| from .export import show_standalone_alternatives | |
| # Configure Console with UTF-8 support for Windows | |
| # Disable legacy Windows rendering to avoid Unicode encoding errors with emojis | |
| console = Console(legacy_windows=False, file=sys.stdout) | |
| def preprocess_latent_sizes_args(args): | |
| """Preprocess command line arguments to handle --latent_sizes with multiple space-separated values. | |
| Converts: --latent_sizes 16 32 64 --ncls ... | |
| To: --latent_sizes 16 --latent_sizes 32 --latent_sizes 64 --ncls ... | |
| """ | |
| result = [] | |
| i = 0 | |
| while i < len(args): | |
| if args[i] == '--latent_sizes' and i + 1 < len(args): | |
| # Found --latent_sizes, collect all following non-option args | |
| i += 1 | |
| # Collect all following non-option arguments until we hit another option | |
| while i < len(args) and not args[i].startswith('--'): | |
| result.append('--latent_sizes') | |
| result.append(args[i]) | |
| i += 1 | |
| # i already points at the next option, so skip the shared increment below | |
| continue | |
| result.append(args[i]) | |
| i += 1 | |
| return result | |
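| # Doctest-style sketch of the rewrite performed above: | |
| # >>> preprocess_latent_sizes_args(['--latent_sizes', '16', '32', '--ncls', '3']) | |
| # ['--latent_sizes', '16', '--latent_sizes', '32', '--ncls', '3'] | |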
| def parse_directions(ctx, param, value): | |
| """Callback to parse directions option, handling both space-separated string and multiple flags. | |
| With multiple=True, Click collects all following non-option arguments into a tuple. | |
| This callback processes that tuple to handle space-separated values. | |
| """ | |
| if not value: | |
| return value | |
| # With multiple=True, Click should collect all following non-option args as a tuple | |
| # e.g., --directions h v diag becomes ('h', 'v', 'diag') | |
| if isinstance(value, tuple): | |
| # Filter out empty strings and strip whitespace | |
| result = tuple(d.strip() for d in value if d and d.strip()) | |
| return result if result else value | |
| # If it's a string (from quoted input like --directions "h v diag"), split it | |
| if isinstance(value, str): | |
| result = tuple(d.strip() for d in value.split() if d.strip()) | |
| return result if result else (value,) | |
| return value | |
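| # Behaviour sketch: parse_directions(None, None, ('h', 'v')) -> ('h', 'v'); | |
| # parse_directions(None, None, ('h v diag',)) -> ('h', 'v', 'diag') | |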
| @click.group() | |
| @click.version_option() | |
| def main(): | |
| """HAEDOSA Automation Engine for DiOodMi. | |
| Orchestrates diffusion-based anomaly detection workflows. | |
| """ | |
| pass | |
| @main.command('train') | |
| @click.option('--model', type=click.Choice(['ddpm', 'ldm', 'decodiff']), | |
| default='decodiff', help='Model type to train') | |
| @click.option('--dataset', type=click.Choice(['mvtec', 'visa', 'custom']), | |
| default='mvtec', help='Dataset to use') | |
| @click.option('--object-category', '--category', 'category', | |
| help='Dataset category (e.g., bottle, wood, all)') | |
| @click.option('--epochs', type=int, default=800, help='Number of training epochs') | |
| @click.option('--batch-size', type=int, help='Global batch size') | |
| @click.option('--data-dir', type=click.Path(exists=True), | |
| help='Path to dataset directory') | |
| @click.option('--csv-split-file', type=click.Path(exists=True), | |
| help='CSV file defining custom dataset splits') | |
| @click.option('--output-dir', type=click.Path(), default='./results', | |
| help='Output directory for checkpoints') | |
| @click.option('--resume-dir', type=click.Path(exists=True), | |
| help='Directory containing checkpoint to resume from') | |
| @click.option('--model-size', type=click.Choice(['UNet_XS', 'UNet_S', 'UNet_M', 'UNet_L', 'UNet_XL']), | |
| help='Model size (for DeCo-Diff)') | |
| @click.option('--image-size', type=int, help='Image size') | |
| @click.option('--crop-size', type=int, help='Center crop size') | |
| @click.option('--center-crop/--no-center-crop', default=None, | |
| help='Use center crop (default: auto-detect based on image/crop size)') | |
| @click.option('--num-workers', type=int, default=0, | |
| help='Number of data loader workers') | |
| @click.option('--image-loading-strategy', type=click.Choice(['resize_first', 'keep_original', 'adaptive']), | |
| help='Image loading strategy (resize_first, keep_original, or adaptive)') | |
| @click.option('--config', type=click.Path(exists=True), | |
| help='Config file path (not yet implemented)') | |
| @click.option('--distributed/--no-distributed', default=False, | |
| help='Use distributed training via torchrun') | |
| @click.option('--gpus', type=int, default=1, | |
| help='Number of GPUs for distributed training') | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show commands without executing') | |
| @click.option('--model-name', type=str, | |
| help='Model name for checkpoint directory') | |
| @click.option('--ckpt-every', type=int, | |
| help='Save checkpoint every N epochs') | |
| @click.option('--bootstrap-samples', type=int, | |
| help='Number of bootstrap samples') | |
| @click.option('--bootstrap-seed', type=int, | |
| help='Bootstrap random seed') | |
| @click.option('--first-n', type=int, | |
| help='Take first N samples') | |
| @click.option('--last-n', type=int, | |
| help='Take last N samples') | |
| @click.option('--range-start', type=int, | |
| help='Start index for range sampling') | |
| @click.option('--range-count', type=int, | |
| help='Count for range sampling') | |
| @click.option('--random-n', type=int, | |
| help='Random N samples (without replacement)') | |
| @click.option('--random-n-with-replacement', type=int, | |
| help='Random N samples (with replacement)') | |
| @click.option('--random-seed', type=int, | |
| help='Random seed for sampling') | |
| def train(model, dataset, category, epochs, batch_size, data_dir, csv_split_file, output_dir, | |
| resume_dir, model_size, image_size, crop_size, center_crop, num_workers, | |
| image_loading_strategy, config, distributed, gpus, dry_run, | |
| model_name, ckpt_every, bootstrap_samples, bootstrap_seed, | |
| first_n, last_n, range_start, range_count, random_n, random_n_with_replacement, random_seed): | |
| """Train a diffusion model for anomaly detection. | |
| Orchestrates dioodmi-train with appropriate configuration. | |
| Examples: | |
| hae train --model decodiff --dataset mvtec --category bottle | |
| hae train --model decodiff --distributed --gpus 2 --dry-run | |
| """ | |
| console.print(f"[bold green]🏋️ Training {model.upper()} on {dataset}[/bold green]\n") | |
| # Build command (platform-specific) | |
| # Windows: use system Python (py -3.11 -m) to access system-installed packages | |
| # Linux/uv: use dioodmi-train command from uv environment | |
| if sys.platform.startswith('win'): | |
| cmd = ["py", "-3.11", "-m", "dioodmi.cli.train"] | |
| else: | |
| cmd = ["dioodmi-train"] | |
| cmd.extend(["--model", model, "--dataset", dataset]) | |
| if category: | |
| cmd.extend(["--object_category", category]) | |
| if epochs: | |
| cmd.extend(["--epochs", str(epochs)]) | |
| if batch_size: | |
| cmd.extend(["--global_batch_size", str(batch_size)]) | |
| if data_dir: | |
| cmd.extend(["--data_dir", data_dir]) | |
| if csv_split_file: | |
| cmd.extend(["--csv_split_file", csv_split_file]) | |
| if output_dir: | |
| cmd.extend(["--output_dir", output_dir]) | |
| if resume_dir: | |
| cmd.extend(["--resume_dir", resume_dir]) | |
| if model_size: | |
| cmd.extend(["--model_size", model_size]) | |
| if image_size: | |
| cmd.extend(["--image_size", str(image_size)]) | |
| if crop_size: | |
| cmd.extend(["--crop_size", str(crop_size)]) | |
| if center_crop is not None: | |
| cmd.extend(["--center_crop", "true" if center_crop else "false"]) | |
| if num_workers is not None: | |
| cmd.extend(["--num_workers", str(num_workers)]) | |
| if image_loading_strategy: | |
| cmd.extend(["--image_loading_strategy", image_loading_strategy]) | |
| if model_name: | |
| cmd.extend(["--model_name", model_name]) | |
| if ckpt_every is not None: | |
| cmd.extend(["--ckpt_every", str(ckpt_every)]) | |
| if bootstrap_samples is not None: | |
| cmd.extend(["--bootstrap-samples", str(bootstrap_samples)]) | |
| if bootstrap_seed is not None: | |
| cmd.extend(["--bootstrap-seed", str(bootstrap_seed)]) | |
| if first_n is not None: | |
| cmd.extend(["--first-n", str(first_n)]) | |
| if last_n is not None: | |
| cmd.extend(["--last-n", str(last_n)]) | |
| if range_start is not None: | |
| cmd.extend(["--range-start", str(range_start)]) | |
| if range_count is not None: | |
| cmd.extend(["--range-count", str(range_count)]) | |
| if random_n is not None: | |
| cmd.extend(["--random-n", str(random_n)]) | |
| if random_n_with_replacement is not None: | |
| cmd.extend(["--random-n-with-replacement", str(random_n_with_replacement)]) | |
| if random_seed is not None: | |
| cmd.extend(["--random-seed", str(random_seed)]) | |
| # Show standalone alternatives | |
| show_standalone_alternatives(cmd, f"Train {model.upper()} on {dataset}" + (f"/{category}" if category else "")) | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| # Execute command | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| if distributed: | |
| # Use torchrun for distributed training | |
| torchrun_cmd = [ | |
| "torchrun", | |
| f"--nproc_per_node={gpus}", | |
| "--nnodes=1", | |
| "--node_rank=0" | |
| ] + cmd | |
| console.print(f"[yellow]Using distributed training on {gpus} GPUs[/yellow]") | |
| console.print(f"[dim]Command: {' '.join(torchrun_cmd)}[/dim]\n") | |
| result = subprocess.run(torchrun_cmd) | |
| else: | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
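| # Illustrative expansion on Linux: `hae train --distributed --gpus 2` runs | |
| # torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 \ | |
| # dioodmi-train --model decodiff --dataset mvtec --epochs 800 | |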
| @main.command('eval') | |
| # Config file option (FIRST - allows loading all parameters from JSON) | |
| @click.option('--config', type=click.Path(exists=True), | |
| help='JSON config file (base + environment overrides)') | |
| # Core parameters | |
| @click.option('--model-path', '--checkpoint', 'model_path', | |
| type=click.Path(exists=True), | |
| help='Path to trained model checkpoint') | |
| @click.option('--dataset', type=click.Choice(['mvtec', 'mvtec2', 'sdp', 'sdp_orig', 'sdp_orig_test', | |
| 'sdp_ok_ng', 'sdp_ok_ng_crop', 'sdp_ok_ng_resize_1024', | |
| 'sdp_ok_ng_resize_2048', 'visa', 'pcb', 'avp', 'iris']), | |
| help='Dataset to evaluate') | |
| @click.option('--object-category', '--category', 'category', | |
| help='Dataset category') | |
| @click.option('--anomaly-class', | |
| help='Anomaly class filter') | |
| @click.option('--data-dir', type=click.Path(exists=True), | |
| help='Path to dataset directory') | |
| @click.option('--csv-split-file', | |
| help='Custom CSV split file (overrides dataset default)') | |
| # Model configuration | |
| @click.option('--model-size', type=click.Choice(['UNet_XS','UNet_S','UNet_M','UNet_L','UNet_XL']), | |
| help='Model size') | |
| @click.option('--vae-type', type=click.Choice(['ema', 'mse']), | |
| help='VAE type (ema or mse)') | |
| @click.option('--image-size', type=int, | |
| help='Image size (integer)') | |
| @click.option('--crop-size', type=int, | |
| help='Crop size') | |
| @click.option('--center-crop', is_flag=True, default=False, | |
| help='Enable center crop') | |
| # Evaluation parameters | |
| # Strategy parameters | |
| @click.option('--augmentation-strategy', | |
| type=click.Choice(['center_crop', 'center_crop_with_affine', | |
| 'random_crop', 'multi_stage', 'none']), | |
| help='Augmentation strategy for evaluation') | |
| @click.option('--image-loading-strategy', | |
| type=click.Choice(['resize_first', 'keep_original', 'adaptive']), | |
| help='Image loading strategy for evaluation') | |
| @click.option('--reverse-steps', type=int, | |
| help='Number of reverse diffusion steps') | |
| @click.option('--batch-size', type=int, | |
| help='Batch size for evaluation') | |
| @click.option('--num-workers', type=int, | |
| help='Number of DataLoader workers') | |
| # Anomaly detection parameters | |
| @click.option('--anomaly-threshold', type=float, | |
| help='Threshold for anomaly detection (0-255 range)') | |
| @click.option('--anomaly-min-area', type=int, | |
| help='Minimum area for detected anomaly regions') | |
| # Confusion matrix parameters | |
| @click.option('--enable-confusion-matrix', is_flag=True, default=None, | |
| help='Enable confusion matrix evaluation') | |
| @click.option('--patch-size', type=int, | |
| help='Patch size for confusion matrix') | |
| @click.option('--anomaly-binary-threshold', type=int, | |
| help='Binary threshold for anomaly detection') | |
| @click.option('--anomaly-pixel-num-threshold', type=int, | |
| help='Minimum pixel count for anomaly') | |
| @click.option('--annotation-dir', type=click.Path(), | |
| help='Directory containing annotation JSON files') | |
| # Output parameters | |
| @click.option('--save-plot-path', '--output-dir', 'output_dir', | |
| type=click.Path(), | |
| help='Results output directory') | |
| @click.option('--save-evaluation-images', is_flag=True, default=None, | |
| help='Save evaluation images in organized subfolders') | |
| @click.option('--save-reconstructions', is_flag=True, default=None, | |
| help='Save reconstructed images (VAE and Deco-Diff) as PNG') | |
| @click.option('--save-reconstructions-npy', is_flag=True, default=None, | |
| help='Save raw reconstruction arrays as NPY files') | |
| @click.option('--save-raw-anomaly-maps', is_flag=True, default=None, | |
| help='Save raw anomaly map arrays as NPY files (image-level)') | |
| @click.option('--save-image-variants', type=str, multiple=True, | |
| help='Which image variants to save (can specify multiple: continuous, binary, grayscale, absolute, all)') | |
| @click.option('--save-colormaps', type=str, multiple=True, | |
| help='Colormaps to use for continuous maps (can specify multiple: jet, hot, viridis, plasma, magma, inferno, turbo)') | |
| @click.option('--save-normalization', type=click.Choice(['minmax', 'zscore', 'percentile', 'raw']), | |
| help='Normalization method for visualization') | |
| @click.option('--normal-center', type=float, | |
| help='Center value for grayscale normalization (default: 125.0)') | |
| @click.option('--save-binary-thresholds', type=float, multiple=True, | |
| help='Threshold values for binary maps (can specify multiple, e.g., 5.0 10.0 15.0)') | |
| @click.option('--save-intermediate-images', is_flag=True, default=None, | |
| help='Save intermediate images (orig_decoded, recon)') | |
| @click.option('--save-npy-files', is_flag=True, default=False, | |
| help='Save anomaly maps as .npy files') | |
| @click.option('--npy-save-dir', type=click.Path(), | |
| help='Directory to save .npy files') | |
| # Full image evaluation | |
| @click.option('--full-image-eval', is_flag=True, default=None, | |
| help='Evaluate whole image by tiling/cropping and stitching') | |
| @click.option('--tile-overlap', type=int, | |
| help='Overlap pixels between tiles for blending') | |
| @click.option('--tile-batch', type=int, | |
| help='How many tiles to process at once on GPU') | |
| # Shift augmentation parameters | |
| @click.option('--pad-px', type=int, | |
| help='Pixel padding for shift augmentation') | |
| @click.option('--stride', type=int, | |
| help='Stride for shift augmentation') | |
| @click.option('--directions', type=str, multiple=True, callback=parse_directions, | |
| help='Shift directions: repeat --directions per value or pass one quoted list (e.g. "h v diag"); values: h v diag all rhombus') | |
| @click.option('--shift-method', type=str, | |
| help='Shift method (interpolate, mirror, etc.)') | |
| @click.option('--fuse-method', type=str, | |
| help='Fuse method (mean, pct25, etc.)') | |
| @click.option('--img-enlarge-px', type=int, | |
| help='Image enlargement in pixels') | |
| # TTA fundamentals saving | |
| @click.option('--save-tta-fundamentals', is_flag=True, default=False, | |
| help='Save per-shift TTA fundamentals (x0, encoded, image_samples, latent_samples)') | |
| @click.option('--tta-cache-dir', type=str, default=None, | |
| help='Directory to save TTA fundamentals NPY files and CSV manifest') | |
| # Dataset sampling options | |
| @click.option('--bootstrap-samples', type=int, default=None, | |
| help='Bootstrap sampling: number of samples (with replacement)') | |
| @click.option('--bootstrap-seed', type=int, default=42, | |
| help='Random seed for bootstrap sampling') | |
| @click.option('--first-n', type=int, default=None, | |
| help='Take first N samples from dataset') | |
| @click.option('--last-n', type=int, default=None, | |
| help='Take last N samples from dataset') | |
| @click.option('--range-start', type=int, default=None, | |
| help='Range sampling: start index (0-based)') | |
| @click.option('--range-count', type=int, default=10, | |
| help='Range sampling: number of samples (default: 10)') | |
| @click.option('--random-n', type=int, default=None, | |
| help='Random N samples without replacement') | |
| @click.option('--random-n-with-replacement', type=int, default=None, | |
| help='Random N samples with replacement') | |
| @click.option('--random-seed', type=int, default=42, | |
| help='Random seed for random sampling methods') | |
| # Utility | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show commands without executing') | |
| def eval(config, model_path, dataset, category, anomaly_class, data_dir, csv_split_file, | |
| model_size, vae_type, image_size, crop_size, center_crop, | |
| augmentation_strategy, image_loading_strategy, | |
| reverse_steps, batch_size, num_workers, | |
| anomaly_threshold, anomaly_min_area, | |
| enable_confusion_matrix, patch_size, anomaly_binary_threshold, | |
| anomaly_pixel_num_threshold, annotation_dir, | |
| output_dir, save_evaluation_images, save_reconstructions, save_reconstructions_npy, | |
| save_raw_anomaly_maps, save_image_variants, save_colormaps, save_normalization, | |
| normal_center, save_binary_thresholds, | |
| save_intermediate_images, save_npy_files, npy_save_dir, | |
| full_image_eval, tile_overlap, tile_batch, | |
| pad_px, stride, directions, shift_method, fuse_method, img_enlarge_px, | |
| save_tta_fundamentals, tta_cache_dir, | |
| bootstrap_samples, bootstrap_seed, first_n, last_n, range_start, range_count, | |
| random_n, random_n_with_replacement, random_seed, | |
| dry_run): | |
| """Evaluate a trained model on test data. | |
| Supports JSON config files with CLI parameter overrides. | |
| Examples: | |
| # With JSON config | |
| hae eval --config config/eval_250923_MVTec.json | |
| # Config + CLI overrides | |
| hae eval --config config/eval.json --batch-size 4 --num-workers 0 | |
| # Pure CLI (no config) | |
| hae eval --model-path models/bottle.pth --dataset mvtec --category bottle | |
| """ | |
| # 1. Load config if provided | |
| config_dict = {} | |
| if config: | |
| from dioodmi.cli.config import load_config_with_env, detect_environment | |
| from pathlib import Path | |
| config_dict = load_config_with_env(Path(config)) | |
| console.print(f"[cyan]📄 Loaded config:[/cyan] {config}") | |
| console.print(f" Environment: {detect_environment()}") | |
| # 2. Build parameter dictionary from CLI arguments | |
| cli_params = { | |
| 'model_path': model_path, | |
| 'dataset': dataset, | |
| 'object_category': category, | |
| 'anomaly_class': anomaly_class, | |
| 'data_dir': data_dir, | |
| 'csv_split_file': csv_split_file, | |
| 'model_size': model_size, | |
| 'vae_type': vae_type, | |
| 'image_size': image_size, | |
| 'crop_size': crop_size, | |
| 'center_crop': center_crop, | |
| 'augmentation_strategy': augmentation_strategy, | |
| 'image_loading_strategy': image_loading_strategy, | |
| 'reverse_steps': reverse_steps, | |
| 'batch_size': batch_size, | |
| 'num_workers': num_workers, | |
| 'anomaly_threshold': anomaly_threshold, | |
| 'anomaly_min_area': anomaly_min_area, | |
| 'enable_confusion_matrix': enable_confusion_matrix, | |
| 'patch_size': patch_size, | |
| 'anomaly_binary_threshold': anomaly_binary_threshold, | |
| 'anomaly_pixel_num_threshold': anomaly_pixel_num_threshold, | |
| 'annotation_dir': annotation_dir, | |
| 'save_plot_path': output_dir, | |
| 'save_evaluation_images': save_evaluation_images, | |
| 'save_reconstructions': save_reconstructions, | |
| 'save_reconstructions_npy': save_reconstructions_npy, | |
| 'save_raw_anomaly_maps': save_raw_anomaly_maps, | |
| 'save_image_variants': list(save_image_variants) if save_image_variants else None, | |
| 'save_colormaps': list(save_colormaps) if save_colormaps else None, | |
| 'save_normalization': save_normalization, | |
| 'normal_center': normal_center, | |
| 'save_binary_thresholds': list(save_binary_thresholds) if save_binary_thresholds else None, | |
| 'save_intermediate_images': save_intermediate_images, | |
| 'save_npy_files': save_npy_files, | |
| 'npy_save_dir': npy_save_dir, | |
| 'full_image_eval': full_image_eval, | |
| 'tile_overlap': tile_overlap, | |
| 'tile_batch': tile_batch, | |
| 'pad_px': pad_px, | |
| 'stride': stride, | |
| 'directions': directions, | |
| 'shift_method': shift_method, | |
| 'fuse_method': fuse_method, | |
| 'img_enlarge_px': img_enlarge_px, | |
| 'save_tta_fundamentals': save_tta_fundamentals, | |
| 'tta_cache_dir': tta_cache_dir, | |
| # Sampling parameters | |
| 'bootstrap_samples': bootstrap_samples, | |
| 'bootstrap_seed': bootstrap_seed, | |
| 'first_n': first_n, | |
| 'last_n': last_n, | |
| 'range_start': range_start, | |
| 'range_count': range_count, | |
| 'random_n': random_n, | |
| 'random_n_with_replacement': random_n_with_replacement, | |
| 'random_seed': random_seed, | |
| } | |
| # 3. Merge: config values first, then override with CLI args | |
| final_params = config_dict.copy() | |
| for key, value in cli_params.items(): | |
| if value is not None: # CLI arg was explicitly provided | |
| final_params[key] = value | |
| # 4. Validate required parameters | |
| if 'model_path' not in final_params or final_params['model_path'] is None: | |
| console.print("[red]❌ Error: --model-path required (or specify in config)[/red]") | |
| sys.exit(1) | |
| # Display evaluation info | |
| eval_dataset = final_params.get('dataset', 'mvtec') | |
| eval_category = final_params.get('object_category', '') | |
| console.print(f"[bold blue]📊 Evaluating model on {eval_dataset}" + | |
| (f"/{eval_category}" if eval_category else "") + "[/bold blue]\n") | |
| # 5. Build platform-specific command | |
| if sys.platform.startswith('win'): | |
| cmd = ["py", "-3.11", "-m", "dioodmi.cli.eval"] | |
| else: | |
| cmd = ["dioodmi-eval"] | |
| # 6. Add all parameters to command | |
| cmd.extend(["--model_path", str(final_params['model_path'])]) | |
| # Dataset parameters | |
| if 'dataset' in final_params: | |
| cmd.extend(["--dataset", final_params['dataset']]) | |
| if 'object_category' in final_params: | |
| cmd.extend(["--object_category", final_params['object_category']]) | |
| if 'anomaly_class' in final_params: | |
| cmd.extend(["--anomaly_class", final_params['anomaly_class']]) | |
| if 'data_dir' in final_params: | |
| cmd.extend(["--data_dir", final_params['data_dir']]) | |
| if 'csv_split_file' in final_params: | |
| cmd.extend(["--csv_split_file", final_params['csv_split_file']]) | |
| # Model configuration | |
| if 'model_size' in final_params: | |
| cmd.extend(["--model_size", final_params['model_size']]) | |
| if 'vae_type' in final_params: | |
| cmd.extend(["--vae_type", final_params['vae_type']]) | |
| if 'image_size' in final_params: | |
| cmd.extend(["--image_size", str(final_params['image_size'])]) | |
| if 'crop_size' in final_params: | |
| cmd.extend(["--crop_size", str(final_params['crop_size'])]) | |
| # Boolean flags (convert to true/false strings for dioodmi.cli.eval) | |
| if 'center_crop' in final_params and final_params['center_crop']: | |
| cmd.extend(["--center_crop", "true"]) | |
| # Strategy parameters | |
| if 'augmentation_strategy' in final_params and final_params['augmentation_strategy']: | |
| cmd.extend(["--augmentation_strategy", final_params['augmentation_strategy']]) | |
| if 'image_loading_strategy' in final_params and final_params['image_loading_strategy']: | |
| cmd.extend(["--image_loading_strategy", final_params['image_loading_strategy']]) | |
| if 'enable_confusion_matrix' in final_params and final_params['enable_confusion_matrix'] is not None: | |
| cmd.extend(["--enable_confusion_matrix", | |
| "true" if final_params['enable_confusion_matrix'] else "false"]) | |
| if 'save_evaluation_images' in final_params and final_params['save_evaluation_images'] is not None: | |
| cmd.extend(["--save_evaluation_images", | |
| "true" if final_params['save_evaluation_images'] else "false"]) | |
| if 'save_reconstructions' in final_params and final_params['save_reconstructions'] is not None: | |
| cmd.extend(["--save_reconstructions", | |
| "true" if final_params['save_reconstructions'] else "false"]) | |
| if 'save_reconstructions_npy' in final_params and final_params['save_reconstructions_npy'] is not None: | |
| cmd.extend(["--save_reconstructions_npy", | |
| "true" if final_params['save_reconstructions_npy'] else "false"]) | |
| if 'save_raw_anomaly_maps' in final_params and final_params['save_raw_anomaly_maps'] is not None: | |
| cmd.extend(["--save_raw_anomaly_maps", | |
| "true" if final_params['save_raw_anomaly_maps'] else "false"]) | |
| if 'save_image_variants' in final_params and final_params['save_image_variants']: | |
| cmd.extend(["--save-image-variants"] + [str(v) for v in final_params['save_image_variants']]) | |
| if 'save_colormaps' in final_params and final_params['save_colormaps']: | |
| cmd.extend(["--save-colormaps"] + [str(v) for v in final_params['save_colormaps']]) | |
| if 'save_normalization' in final_params and final_params['save_normalization']: | |
| cmd.extend(["--save-normalization", str(final_params['save_normalization'])]) | |
| if 'normal_center' in final_params and final_params['normal_center'] is not None: | |
| cmd.extend(["--normal-center", str(final_params['normal_center'])]) | |
| if 'save_binary_thresholds' in final_params and final_params['save_binary_thresholds']: | |
| cmd.extend(["--save-binary-thresholds"] + [str(v) for v in final_params['save_binary_thresholds']]) | |
| if 'save_intermediate_images' in final_params and final_params['save_intermediate_images'] is not None: | |
| cmd.extend(["--save_intermediate_images", | |
| "true" if final_params['save_intermediate_images'] else "false"]) | |
| if 'save_npy_files' in final_params and final_params['save_npy_files']: | |
| cmd.extend(["--save_npy_files", "true"]) | |
| if 'full_image_eval' in final_params and final_params['full_image_eval'] is not None: | |
| cmd.extend(["--full_image_eval", | |
| "true" if final_params['full_image_eval'] else "false"]) | |
| # Evaluation parameters | |
| if 'reverse_steps' in final_params: | |
| cmd.extend(["--reverse_steps", str(final_params['reverse_steps'])]) | |
| if 'batch_size' in final_params: | |
| cmd.extend(["--batch_size", str(final_params['batch_size'])]) | |
| if 'num_workers' in final_params: | |
| cmd.extend(["--num_workers", str(final_params['num_workers'])]) | |
| # Anomaly detection parameters | |
| if 'anomaly_threshold' in final_params: | |
| cmd.extend(["--anomaly_threshold", str(final_params['anomaly_threshold'])]) | |
| if 'anomaly_min_area' in final_params: | |
| cmd.extend(["--anomaly_min_area", str(final_params['anomaly_min_area'])]) | |
| # Confusion matrix parameters | |
| if 'patch_size' in final_params: | |
| cmd.extend(["--patch_size", str(final_params['patch_size'])]) | |
| if 'anomaly_binary_threshold' in final_params: | |
| cmd.extend(["--anomaly_binary_threshold", str(final_params['anomaly_binary_threshold'])]) | |
| if 'anomaly_pixel_num_threshold' in final_params: | |
| cmd.extend(["--anomaly_pixel_num_threshold", str(final_params['anomaly_pixel_num_threshold'])]) | |
| if 'annotation_dir' in final_params: | |
| cmd.extend(["--annotation_dir", final_params['annotation_dir']]) | |
| # Output parameters | |
| if 'save_plot_path' in final_params: | |
| cmd.extend(["--save_plot_path", final_params['save_plot_path']]) | |
| if 'npy_save_dir' in final_params: | |
| cmd.extend(["--npy_save_dir", final_params['npy_save_dir']]) | |
| # Full image evaluation parameters | |
| if 'tile_overlap' in final_params: | |
| cmd.extend(["--tile_overlap", str(final_params['tile_overlap'])]) | |
| if 'tile_batch' in final_params: | |
| cmd.extend(["--tile_batch", str(final_params['tile_batch'])]) | |
| # Shift augmentation parameters (only pass if pad_px is set, indicating TTA is enabled) | |
| if 'pad_px' in final_params and final_params['pad_px'] is not None: | |
| cmd.extend(["--pad_px", str(final_params['pad_px'])]) | |
| if 'img_enlarge_px' in final_params: | |
| cmd.extend(["--img_enlarge_px", str(final_params['img_enlarge_px'])]) | |
| if 'stride' in final_params: | |
| cmd.extend(["--stride", str(final_params['stride'])]) | |
| if 'directions' in final_params and final_params['directions']: | |
| # dioodmi-eval uses argparse nargs="+" so it expects: --directions h v diag | |
| directions = final_params['directions'] | |
| if isinstance(directions, (tuple, list)): | |
| cmd.extend(["--directions"] + list(directions)) | |
| else: | |
| # Handle string case (from config file): "h v diag" -> ["h", "v", "diag"] | |
| cmd.extend(["--directions"] + directions.split()) | |
| if 'shift_method' in final_params: | |
| cmd.extend(["--shift_method", final_params['shift_method']]) | |
| if 'fuse_method' in final_params: | |
| cmd.extend(["--fuse_method", final_params['fuse_method']]) | |
| # TTA fundamentals saving | |
| if 'save_tta_fundamentals' in final_params and final_params['save_tta_fundamentals']: | |
| cmd.append("--save_tta_fundamentals") | |
| if 'tta_cache_dir' in final_params and final_params['tta_cache_dir']: | |
| cmd.extend(["--tta_cache_dir", str(final_params['tta_cache_dir'])]) | |
| # Dataset sampling parameters | |
| if 'bootstrap_samples' in final_params and final_params['bootstrap_samples'] is not None: | |
| cmd.extend(["--bootstrap_samples", str(final_params['bootstrap_samples'])]) | |
| if 'bootstrap_seed' in final_params: | |
| cmd.extend(["--bootstrap_seed", str(final_params['bootstrap_seed'])]) | |
| if 'first_n' in final_params and final_params['first_n'] is not None: | |
| cmd.extend(["--first_n", str(final_params['first_n'])]) | |
| if 'last_n' in final_params and final_params['last_n'] is not None: | |
| cmd.extend(["--last_n", str(final_params['last_n'])]) | |
| if 'range_start' in final_params and final_params['range_start'] is not None: | |
| cmd.extend(["--range_start", str(final_params['range_start'])]) | |
| if 'range_count' in final_params: | |
| cmd.extend(["--range_count", str(final_params['range_count'])]) | |
| if 'random_n' in final_params and final_params['random_n'] is not None: | |
| cmd.extend(["--random_n", str(final_params['random_n'])]) | |
| if 'random_n_with_replacement' in final_params and final_params['random_n_with_replacement'] is not None: | |
| cmd.extend(["--random_n_with_replacement", str(final_params['random_n_with_replacement'])]) | |
| if 'random_seed' in final_params: | |
| cmd.extend(["--random_seed", str(final_params['random_seed'])]) | |
| # 7. Show standalone alternative | |
| show_standalone_alternatives(cmd, f"Evaluate on {eval_dataset}" + | |
| (f"/{eval_category}" if eval_category else "")) | |
| # 8. Execute | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
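| # Hypothetical JSON config sketch for --config (keys mirror cli_params above; | |
| # file name and values are illustrative, not a shipped example): | |
| # { | |
| # "model_path": "models/bottle.pth", | |
| # "dataset": "mvtec", | |
| # "object_category": "bottle", | |
| # "reverse_steps": 5, | |
| # "pad_px": 8, | |
| # "directions": "h v" | |
| # } | |
| # Any flag passed on the command line overrides the matching config key | |
| # during the merge in step 3 above. | |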
| @main.command('anomaly-detect') | |
| # Config file option (FIRST - allows loading all parameters from JSON) | |
| @click.option('--config', type=click.Path(exists=True), | |
| help='JSON config file (base + environment overrides)') | |
| # Mode selection | |
| @click.option('--mode', type=click.Choice(['detect_and_evaluate_without_cache', 'detect_and_evaluate', | |
| 'detect_only', 'evaluate_only']), | |
| default='detect_and_evaluate', | |
| help='Execution mode (default: detect_and_evaluate)') | |
| @click.option('--npy-cache-dir', type=click.Path(), | |
| help='Directory for NPY cache files (defaults to {save_plot_path}/npy_cache)') | |
| # Core parameters | |
| @click.option('--model-path', '--checkpoint', 'model_path', | |
| type=click.Path(exists=True), | |
| help='Path to trained model checkpoint') | |
| @click.option('--dataset', type=click.Choice(['mvtec', 'mvtec2', 'sdp', 'sdp_orig', 'sdp_orig_test', | |
| 'sdp_ok_ng', 'sdp_ok_ng_crop', 'sdp_ok_ng_resize_1024', | |
| 'sdp_ok_ng_resize_2048', 'visa', 'pcb', 'avp', 'iris']), | |
| default='mvtec', | |
| help='Dataset to evaluate') | |
| @click.option('--object-category', '--category', 'category', | |
| default='all', | |
| help='Dataset category') | |
| @click.option('--anomaly-class', | |
| default='all', | |
| help='Anomaly class filter') | |
| @click.option('--data-dir', type=click.Path(exists=True), | |
| help='Path to dataset directory') | |
| @click.option('--csv-split-file', | |
| help='Custom CSV split file (overrides dataset default)') | |
| # Model configuration | |
| @click.option('--model-size', type=click.Choice(['UNet_XS','UNet_S','UNet_M','UNet_L','UNet_XL']), | |
| default='UNet_L', | |
| help='Model size') | |
| @click.option('--vae-type', type=click.Choice(['ema', 'mse']), | |
| default='ema', | |
| help='VAE type (ema or mse)') | |
| @click.option('--image-size', type=str, | |
| help='Image size (integer or "None" for dataset default)') | |
| @click.option('--crop-size', type=int, default=256, | |
| help='Crop size') | |
| @click.option('--center-crop/--no-center-crop', default=True, | |
| help='Center crop (default: enabled; use --no-center-crop to disable)') | |
| # Strategy parameters | |
| @click.option('--augmentation-strategy', | |
| type=click.Choice(['center_crop', 'center_crop_with_affine', | |
| 'random_crop', 'multi_stage', 'none']), | |
| help='Augmentation strategy for evaluation') | |
| @click.option('--image-loading-strategy', | |
| type=click.Choice(['resize_first', 'keep_original', 'adaptive']), | |
| help='Image loading strategy (preprocessing before augmentation)') | |
| # Evaluation parameters | |
| @click.option('--reverse-steps', type=int, default=5, | |
| help='Number of reverse diffusion steps') | |
| @click.option('--batch-size', type=int, default=1, | |
| help='DataLoader batch size (process B images together)') | |
| @click.option('--num-workers', type=int, default=4, | |
| help='Number of DataLoader workers') | |
| # TTA parameters (batched version - AnomalyDetector specific) | |
| @click.option('--pad-px', type=int, default=None, | |
| help='Shift range for TTA (±pad_px pixels). Set to enable TTA, None to disable') | |
| @click.option('--img-enlarge-px', type=int, default=4, | |
| help='Image enlargement for interpolation method') | |
| @click.option('--stride', type=int, default=1, | |
| help='Shift increment (1=every pixel, 2=skip pixels)') | |
| @click.option('--directions', multiple=True, callback=parse_directions, | |
| help='Shift directions: repeat --directions per value or pass one quoted list (e.g. "h v diag"); values: h v diag all rhombus') | |
| @click.option('--shift-method', type=click.Choice(['interpolate', 'mirror']), | |
| default='interpolate', | |
| help='Shift augmentation method') | |
| @click.option('--fuse-method', type=click.Choice(['mean', 'median', 'lowest', 'pct25', 'pct75']), | |
| default='mean', | |
| help='Method to fuse multiple shifted predictions') | |
| @click.option('--tta-batch-size', type=int, default=8, | |
| help='Batch size for TTA shift processing (combined with DataLoader batch)') | |
| # Anomaly detection parameters | |
| @click.option('--anomaly-threshold', type=float, default=20, | |
| help='Threshold for anomaly detection (0-255 range, for full_image_eval)') | |
| @click.option('--anomaly-min-area', type=int, default=10, | |
| help='Minimum area for detected anomaly regions (for full_image_eval)') | |
| # Output parameters | |
| @click.option('--save-plot-path', '--output-dir', 'output_dir', | |
| type=click.Path(), | |
| help='Results output directory') | |
| @click.option('--save-evaluation-images', type=str, default='false', | |
| help='Save evaluation images in organized subfolders (true/false)') | |
| @click.option('--annotation-dir', type=click.Path(), | |
| help='Directory containing annotation JSON files for TP/FP/FN/TN classification') | |
| @click.option('--save-reconstructions', is_flag=True, default=None, | |
| help='Save reconstructed images (VAE and Deco-Diff) as PNG') | |
| @click.option('--save-reconstructions-npy', is_flag=True, default=None, | |
| help='Save raw reconstruction arrays as NPY files') | |
| @click.option('--save-raw-anomaly-maps', is_flag=True, default=None, | |
| help='Save raw anomaly map arrays as NPY files (image-level)') | |
| @click.option('--save-image-variants', type=str, multiple=True, | |
| help='Which image variants to save (can specify multiple: continuous, binary, grayscale, absolute, all)') | |
| @click.option('--save-colormaps', type=str, multiple=True, | |
| help='Colormaps to use for continuous maps (can specify multiple: jet, hot, viridis, plasma, magma, inferno, turbo)') | |
| @click.option('--save-normalization', type=click.Choice(['minmax', 'zscore', 'percentile', 'raw']), | |
| help='Normalization method for visualization') | |
| @click.option('--normal-center', type=float, | |
| help='Center value for grayscale normalization (default: 125.0)') | |
| @click.option('--save-binary-thresholds', type=float, multiple=True, | |
| help='Threshold values for binary maps (can specify multiple, e.g., 5.0 10.0 15.0)') | |
| @click.option('--checkpoint-enabled', is_flag=True, default=False, | |
| help='Enable checkpoint functionality for image saving') | |
| # NPY file saving (for regression testing) | |
| @click.option('--save-npy-files', is_flag=True, default=False, | |
| help='Save anomaly maps as .npy files (for regression testing)') | |
| @click.option('--npy-save-dir', type=click.Path(), | |
| help='Directory to save .npy files') | |
| # Filename generation strategy | |
| @click.option('--filename-strategy', type=click.Choice(['auto', 'simple', 'anomaly_type', 'full_path', | |
| 'incremental', 'dataset_index']), | |
| default='auto', | |
| help='Filename generation strategy') | |
| @click.option('--regression-test-mode', is_flag=True, default=False, | |
| help='Enable regression test mode (forces simple filename strategy)') | |
| # Dataset sampling options | |
| @click.option('--bootstrap-samples', type=int, default=None, | |
| help='Bootstrap sampling: number of samples (with replacement)') | |
| @click.option('--bootstrap-seed', type=int, default=42, | |
| help='Random seed for bootstrap sampling') | |
| @click.option('--first-n', type=int, default=None, | |
| help='Take first N samples from dataset') | |
| @click.option('--last-n', type=int, default=None, | |
| help='Take last N samples from dataset') | |
| @click.option('--range-start', type=int, default=None, | |
| help='Range sampling: start index (0-based)') | |
| @click.option('--range-count', type=int, default=10, | |
| help='Range sampling: number of samples (default: 10)') | |
| @click.option('--random-n', type=int, default=None, | |
| help='Random N samples without replacement') | |
| @click.option('--random-n-with-replacement', type=int, default=None, | |
| help='Random N samples with replacement') | |
| @click.option('--random-seed', type=int, default=42, | |
| help='Random seed for random sampling methods') | |
| # I/O optimization parameters | |
| @click.option('--use_consolidated_npz', is_flag=True, default=False, | |
| help='Use single compressed .npz file per image instead of multiple .npy files (4-6× faster I/O)') | |
| # Utility | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show commands without executing') | |
| def anomaly_detect(config, mode, npy_cache_dir, model_path, dataset, category, anomaly_class, data_dir, csv_split_file, | |
| model_size, vae_type, image_size, crop_size, center_crop, | |
| augmentation_strategy, image_loading_strategy, | |
| reverse_steps, batch_size, num_workers, | |
| pad_px, img_enlarge_px, stride, directions, shift_method, fuse_method, tta_batch_size, | |
| anomaly_threshold, anomaly_min_area, | |
| output_dir, save_evaluation_images, annotation_dir, save_reconstructions, save_reconstructions_npy, | |
| save_raw_anomaly_maps, save_image_variants, save_colormaps, save_normalization, | |
| normal_center, save_binary_thresholds, checkpoint_enabled, | |
| save_npy_files, npy_save_dir, | |
| filename_strategy, regression_test_mode, | |
| bootstrap_samples, bootstrap_seed, first_n, last_n, range_start, range_count, | |
| random_n, random_n_with_replacement, random_seed, | |
| use_consolidated_npz, | |
| dry_run): | |
| """Run anomaly detection using modern AnomalyDetector (consolidated evaluator). | |
| Uses the modern AnomalyDetector implementation which consolidates DecodiffEvaluator | |
| and DecodiffEvaluateProcessor with efficient batched TTA (~6.6× speedup). | |
| Features: | |
| - Batched TTA for ~6.6× speedup vs sequential TTA | |
| - Configurable shifts: --pad-px, --stride, --directions | |
| - Clean architecture without legacy code | |
| Examples: | |
| # Basic evaluation | |
| hae anomaly-detect --model-path models/bottle.pth --dataset mvtec --category bottle | |
| # With TTA enabled | |
| hae anomaly-detect --model-path models/bottle.pth --pad-px 4 --tta-batch-size 8 | |
| # With config file | |
| hae anomaly-detect --config config/eval.json --pad-px 4 | |
| """ | |
| # 1. Load config if provided | |
| config_dict = {} | |
| if config: | |
| from dioodmi.cli.config import load_config_with_env, detect_environment | |
| from pathlib import Path | |
| config_dict = load_config_with_env(Path(config)) | |
| console.print(f"[cyan]📄 Loaded config:[/cyan] {config}") | |
| console.print(f" Environment: {detect_environment()}") | |
| # 2. Build parameter dictionary from CLI arguments | |
| cli_params = { | |
| 'mode': mode, | |
| 'npy_cache_dir': npy_cache_dir, | |
| 'model_path': model_path, | |
| 'dataset': dataset, | |
| 'object_category': category, | |
| 'anomaly_class': anomaly_class, | |
| 'data_dir': data_dir, | |
| 'csv_split_file': csv_split_file, | |
| 'model_size': model_size, | |
| 'vae_type': vae_type, | |
| 'image_size': image_size, | |
| 'crop_size': crop_size, | |
| 'center_crop': center_crop, | |
| 'augmentation_strategy': augmentation_strategy, | |
| 'image_loading_strategy': image_loading_strategy, | |
| 'reverse_steps': reverse_steps, | |
| 'batch_size': batch_size, | |
| 'num_workers': num_workers, | |
| 'pad_px': pad_px, | |
| 'img_enlarge_px': img_enlarge_px, | |
| 'stride': stride, | |
| 'directions': directions, | |
| 'shift_method': shift_method, | |
| 'fuse_method': fuse_method, | |
| 'tta_batch_size': tta_batch_size, | |
| 'anomaly_threshold': anomaly_threshold, | |
| 'anomaly_min_area': anomaly_min_area, | |
| 'save_plot_path': output_dir, | |
| 'save_evaluation_images': save_evaluation_images, | |
| 'annotation_dir': annotation_dir, | |
| 'save_reconstructions': save_reconstructions, | |
| 'save_reconstructions_npy': save_reconstructions_npy, | |
| 'save_raw_anomaly_maps': save_raw_anomaly_maps, | |
| 'save_image_variants': list(save_image_variants) if save_image_variants else None, | |
| 'save_colormaps': list(save_colormaps) if save_colormaps else None, | |
| 'save_normalization': save_normalization, | |
| 'normal_center': normal_center, | |
| 'save_binary_thresholds': list(save_binary_thresholds) if save_binary_thresholds else None, | |
| 'checkpoint_enabled': checkpoint_enabled, | |
| 'save_npy_files': save_npy_files, | |
| 'npy_save_dir': npy_save_dir, | |
| 'filename_strategy': filename_strategy, | |
| 'regression_test_mode': regression_test_mode, | |
| # Sampling parameters | |
| 'bootstrap_samples': bootstrap_samples, | |
| 'bootstrap_seed': bootstrap_seed, | |
| 'first_n': first_n, | |
| 'last_n': last_n, | |
| 'range_start': range_start, | |
| 'range_count': range_count, | |
| 'random_n': random_n, | |
| 'random_n_with_replacement': random_n_with_replacement, | |
| 'random_seed': random_seed, | |
| 'use_consolidated_npz': use_consolidated_npz, | |
| } | |
| # 3. Merge: config values first, then override with CLI args | |
| final_params = config_dict.copy() | |
| for key, value in cli_params.items(): | |
| if value is not None: # CLI arg was explicitly provided | |
| final_params[key] = value | |
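| # Resulting precedence, illustrated with hypothetical values: | |
| #   config_dict = {'batch_size': 4, 'reverse_steps': 10} | |
| #   CLI: --batch-size 2  ->  final_params = {'batch_size': 2, 'reverse_steps': 10} | |
| # Caveat: options with non-None defaults (e.g. --mode, --dataset, --stride) | |
| # always carry a value here, so they override the config even when not given | |
| # on the command line; only default=None options fall through to the config. | |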
| # 4. Validate required parameters | |
| if 'model_path' not in final_params or final_params['model_path'] is None: | |
| console.print("[red]❌ Error: --model-path required (or specify in config)[/red]") | |
| sys.exit(1) | |
| # Display evaluation info | |
| eval_dataset = final_params.get('dataset', 'mvtec') | |
| eval_category = final_params.get('object_category', 'all') | |
| console.print(f"[bold green]🔍 Anomaly Detection (Modern Version)[/bold green]") | |
| console.print(f"[bold]Dataset:[/bold] {eval_dataset}" + | |
| (f"/{eval_category}" if eval_category != 'all' else "")) | |
| console.print(f"[bold]Model:[/bold] {final_params.get('model_size', 'UNet_L')}\n") | |
| # Show TTA status | |
| if final_params.get('pad_px') is not None: | |
| directions_display = final_params.get('directions', ('h', 'v')) | |
| if isinstance(directions_display, tuple): | |
| directions_display = ' '.join(directions_display) | |
| console.print(f"[cyan]✓ TTA Enabled:[/cyan]") | |
| console.print(f" Shift range: ±{final_params['pad_px']} pixels") | |
| console.print(f" Stride: {final_params.get('stride', 1)}") | |
| console.print(f" Directions: {directions_display}") | |
| console.print(f" TTA batch size: {final_params.get('tta_batch_size', 8)}") | |
| console.print(f" Expected speedup: ~6.6× vs sequential TTA\n") | |
| else: | |
| console.print(f"[yellow]✗ TTA Disabled (use --pad-px to enable)[/yellow]\n") | |
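| # Worked example of the TTA shift grid (a sketch; the exact enumeration is | |
| # owned by parse_directions): pad_px=4, stride=1, directions=('h', 'v') | |
| # presumably yields dx in {-4..-1, 1..4} (8 horizontal views), the same 8 | |
| # vertical views, and the unshifted image, i.e. ~17 views per input, all | |
| # batched together and then fused by --fuse-method. | |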
| # 5. Build platform-specific command | |
| if sys.platform.startswith('win'): | |
| cmd = ["py", "-3.11", "-m", "dioodmi.cli.ad"] | |
| else: | |
| cmd = ["dioodmi-ad"] | |
| # 6. Add all parameters to command | |
| # Mode selection (must come first for positional parsing) | |
| if 'mode' in final_params: | |
| cmd.extend(["--mode", final_params['mode']]) | |
| if 'npy_cache_dir' in final_params and final_params['npy_cache_dir']: | |
| cmd.extend(["--npy_cache_dir", final_params['npy_cache_dir']]) | |
| cmd.extend(["--model_path", str(final_params['model_path'])]) | |
| # Dataset parameters | |
| if 'dataset' in final_params: | |
| cmd.extend(["--dataset", final_params['dataset']]) | |
| if 'object_category' in final_params: | |
| cmd.extend(["--object_category", final_params['object_category']]) | |
| if 'anomaly_class' in final_params: | |
| cmd.extend(["--anomaly_class", final_params['anomaly_class']]) | |
| if 'data_dir' in final_params: | |
| cmd.extend(["--data_dir", final_params['data_dir']]) | |
| if 'csv_split_file' in final_params: | |
| cmd.extend(["--csv_split_file", final_params['csv_split_file']]) | |
| # Model configuration | |
| if 'model_size' in final_params: | |
| cmd.extend(["--model_size", final_params['model_size']]) | |
| if 'vae_type' in final_params: | |
| cmd.extend(["--vae_type", final_params['vae_type']]) | |
| if 'image_size' in final_params: | |
| if final_params['image_size'] is None: | |
| cmd.extend(["--image_size", "None"]) | |
| else: | |
| cmd.extend(["--image_size", str(final_params['image_size'])]) | |
| if 'crop_size' in final_params: | |
| cmd.extend(["--crop_size", str(final_params['crop_size'])]) | |
| if 'center_crop' in final_params: | |
| cmd.extend(["--center_crop", "true" if final_params['center_crop'] else "false"]) | |
| # Strategy parameters | |
| if 'augmentation_strategy' in final_params and final_params['augmentation_strategy']: | |
| cmd.extend(["--augmentation_strategy", final_params['augmentation_strategy']]) | |
| if 'image_loading_strategy' in final_params and final_params['image_loading_strategy']: | |
| cmd.extend(["--image_loading_strategy", final_params['image_loading_strategy']]) | |
| # Evaluation parameters | |
| if 'reverse_steps' in final_params: | |
| cmd.extend(["--reverse_steps", str(final_params['reverse_steps'])]) | |
| if 'batch_size' in final_params: | |
| cmd.extend(["--batch_size", str(final_params['batch_size'])]) | |
| if 'num_workers' in final_params: | |
| cmd.extend(["--num_workers", str(final_params['num_workers'])]) | |
| # TTA parameters (only pass if pad_px is set, indicating TTA is enabled) | |
| if 'pad_px' in final_params and final_params['pad_px'] is not None: | |
| cmd.extend(["--pad_px", str(final_params['pad_px'])]) | |
| if 'img_enlarge_px' in final_params: | |
| cmd.extend(["--img_enlarge_px", str(final_params['img_enlarge_px'])]) | |
| if 'stride' in final_params: | |
| cmd.extend(["--stride", str(final_params['stride'])]) | |
| if 'directions' in final_params and final_params['directions']: | |
| # dioodmi-ad uses argparse nargs="+" so it expects: --directions h v diag | |
| directions = final_params['directions'] | |
| if isinstance(directions, (tuple, list)): | |
| cmd.extend(["--directions"] + list(directions)) | |
| else: | |
| # Handle string case (from config file): "h v diag" -> ["h", "v", "diag"] | |
| cmd.extend(["--directions"] + directions.split()) | |
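| # e.g. directions=('h', 'v') from repeated flags -> cmd += ['--directions', 'h', 'v'] | |
| # while a config string "h v diag" -> cmd += ['--directions', 'h', 'v', 'diag'] | |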
| if 'shift_method' in final_params: | |
| cmd.extend(["--shift_method", final_params['shift_method']]) | |
| if 'fuse_method' in final_params: | |
| cmd.extend(["--fuse_method", final_params['fuse_method']]) | |
| if 'tta_batch_size' in final_params: | |
| cmd.extend(["--tta_batch_size", str(final_params['tta_batch_size'])]) | |
| # Anomaly detection parameters | |
| if 'anomaly_threshold' in final_params and final_params['anomaly_threshold'] is not None: | |
| cmd.extend(["--anomaly_threshold", str(final_params['anomaly_threshold'])]) | |
| if 'anomaly_min_area' in final_params and final_params['anomaly_min_area'] is not None: | |
| cmd.extend(["--anomaly_min_area", str(final_params['anomaly_min_area'])]) | |
| # Output parameters | |
| if 'save_plot_path' in final_params: | |
| cmd.extend(["--save_plot_path", final_params['save_plot_path']]) | |
| if 'save_evaluation_images' in final_params: | |
| # Convert to boolean and pass as string value | |
| save_eval_bool = str(final_params['save_evaluation_images']).lower() in ('true', '1', 'yes') | |
| cmd.extend(["--save_evaluation_images", "true" if save_eval_bool else "false"]) | |
| if 'annotation_dir' in final_params and final_params['annotation_dir'] is not None: | |
| cmd.extend(["--annotation_dir", final_params['annotation_dir']]) | |
| if 'save_reconstructions' in final_params and final_params['save_reconstructions'] is not None: | |
| cmd.extend(["--save_reconstructions", "true" if final_params['save_reconstructions'] else "false"]) | |
| if 'save_reconstructions_npy' in final_params and final_params['save_reconstructions_npy'] is not None: | |
| cmd.extend(["--save_reconstructions_npy", "true" if final_params['save_reconstructions_npy'] else "false"]) | |
| if 'save_raw_anomaly_maps' in final_params and final_params['save_raw_anomaly_maps'] is not None: | |
| cmd.extend(["--save_raw_anomaly_maps", "true" if final_params['save_raw_anomaly_maps'] else "false"]) | |
| if 'save_image_variants' in final_params and final_params['save_image_variants']: | |
| cmd.extend(["--save-image-variants"] + [str(v) for v in final_params['save_image_variants']]) | |
| if 'save_colormaps' in final_params and final_params['save_colormaps']: | |
| cmd.extend(["--save-colormaps"] + [str(v) for v in final_params['save_colormaps']]) | |
| if 'save_normalization' in final_params and final_params['save_normalization']: | |
| cmd.extend(["--save-normalization", str(final_params['save_normalization'])]) | |
| if 'normal_center' in final_params and final_params['normal_center'] is not None: | |
| cmd.extend(["--normal-center", str(final_params['normal_center'])]) | |
| if 'save_binary_thresholds' in final_params and final_params['save_binary_thresholds']: | |
| cmd.extend(["--save-binary-thresholds"] + [str(v) for v in final_params['save_binary_thresholds']]) | |
| if 'checkpoint_enabled' in final_params and final_params['checkpoint_enabled'] is not None: | |
| cmd.extend(["--checkpoint_enabled", "true" if final_params['checkpoint_enabled'] else "false"]) | |
| # NPY file saving (for regression testing) | |
| if 'save_npy_files' in final_params and final_params['save_npy_files']: | |
| cmd.extend(["--save_npy_files", "true"]) | |
| if 'npy_save_dir' in final_params and final_params['npy_save_dir']: | |
| cmd.extend(["--npy_save_dir", final_params['npy_save_dir']]) | |
| # Filename strategy | |
| if 'filename_strategy' in final_params: | |
| cmd.extend(["--filename_strategy", final_params['filename_strategy']]) | |
| if 'regression_test_mode' in final_params and final_params['regression_test_mode']: | |
| cmd.append("--regression_test_mode") | |
| # Sampling parameters | |
| if 'bootstrap_samples' in final_params and final_params['bootstrap_samples'] is not None: | |
| cmd.extend(["--bootstrap-samples", str(final_params['bootstrap_samples'])]) | |
| if 'bootstrap_seed' in final_params: | |
| cmd.extend(["--bootstrap-seed", str(final_params['bootstrap_seed'])]) | |
| if 'first_n' in final_params and final_params['first_n'] is not None: | |
| cmd.extend(["--first-n", str(final_params['first_n'])]) | |
| if 'last_n' in final_params and final_params['last_n'] is not None: | |
| cmd.extend(["--last-n", str(final_params['last_n'])]) | |
| if 'range_start' in final_params and final_params['range_start'] is not None: | |
| cmd.extend(["--range-start", str(final_params['range_start'])]) | |
| if 'range_count' in final_params: | |
| cmd.extend(["--range-count", str(final_params['range_count'])]) | |
| if 'random_n' in final_params and final_params['random_n'] is not None: | |
| cmd.extend(["--random-n", str(final_params['random_n'])]) | |
| if 'random_seed' in final_params: | |
| cmd.extend(["--random-seed", str(final_params['random_seed'])]) | |
| if 'random_n_with_replacement' in final_params and final_params['random_n_with_replacement'] is not None: | |
| cmd.extend(["--random-n-with-replacement", str(final_params['random_n_with_replacement'])]) | |
| # I/O optimization parameters | |
| if 'use_consolidated_npz' in final_params and final_params['use_consolidated_npz']: | |
| cmd.extend(["--use_consolidated_npz", "true"]) | |
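| # The consolidated format presumably maps to numpy's savez_compressed, e.g.: | |
| #   np.savez_compressed(out_dir / f'{stem}.npz', anomaly_map=amap, recon=recon) | |
| #   data = np.load(out_dir / f'{stem}.npz'); amap = data['anomaly_map'] | |
| # versus one np.save() call per array in the multi-.npy default. | |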
| # 7. Show standalone alternative | |
| show_standalone_alternatives(cmd, f"Anomaly Detection on {eval_dataset}" + | |
| (f"/{eval_category}" if eval_category != 'all' else "")) | |
| # 8. Execute | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
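| # For reference, the forwarded command assembled above looks roughly like | |
| # (illustrative values; `py -3.11 -m dioodmi.cli.ad ...` on Windows): | |
| #   dioodmi-ad --mode detect_and_evaluate --model_path models/bottle.pth \ | |
| #     --dataset mvtec --object_category bottle --reverse_steps 5 --batch_size 1 | |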
| @main.command('eval-process') | |
| # Config file option (FIRST - allows loading all parameters from JSON) | |
| @click.option('--config', type=click.Path(exists=True), | |
| help='JSON config file (base + environment overrides)') | |
| # Mode selection (REQUIRED) | |
| @click.option('--mode', type=click.Choice(['save_only', 'process_only', 'save_and_process', | |
| 'full_pipeline', 'full_pipeline_with_saving_npy', | |
| 'save_and_process_with_bootstrap']), | |
| help='Execution mode (REQUIRED unless specified in config)') | |
| # Core parameters | |
| @click.option('--pretrained', '--model-path', 'pretrained', | |
| type=click.Path(exists=True), | |
| help='Path to trained model checkpoint (REQUIRED for save modes)') | |
| @click.option('--annotation-dir', type=click.Path(), | |
| help='Directory containing annotation files (REQUIRED)') | |
| @click.option('--dataset', type=click.Choice(['mvtec', 'mvtec2', 'sdp', 'sdp_orig', 'sdp_orig_test', | |
| 'sdp_ok_ng', 'sdp_ok_ng_crop', 'sdp_ok_ng_resize_1024', | |
| 'sdp_ok_ng_resize_2048', 'visa', 'pcb', 'avp', 'iris']), | |
| default='mvtec', | |
| help='Dataset to evaluate') | |
| @click.option('--object-class', default='all', | |
| help='Object class to process') | |
| @click.option('--anomaly-class', | |
| help='Anomaly class filter') | |
| @click.option('--data-dir', type=click.Path(exists=True), | |
| help='Path to dataset directory') | |
| @click.option('--csv-split-file', | |
| help='Custom CSV split file (overrides dataset default)') | |
| # Model configuration | |
| @click.option('--model-size', type=click.Choice(['UNet_XS','UNet_S','UNet_M','UNet_L','UNet_XL']), | |
| default='UNet_L', | |
| help='Model size') | |
| @click.option('--vae-type', type=click.Choice(['ema', 'mse']), | |
| default='ema', | |
| help='VAE type (ema or mse)') | |
| @click.option('--image-size', type=int, | |
| help='Image size (integer)') | |
| @click.option('--center-size', type=int, | |
| help='Center size') | |
| @click.option('--center-crop', type=str, | |
| help='Center crop (string)') | |
| # Processing parameters | |
| @click.option('--patch-size', type=int, default=128, | |
| help='Patch size for image processing') | |
| @click.option('--stride', type=int, | |
| help='Stride for patch extraction (None = no overlap)') | |
| @click.option('--irregular-patch', is_flag=True, default=False, | |
| help='Use irregular patch for image processing') | |
| @click.option('--reverse-steps', type=int, default=5, | |
| help='Number of reverse diffusion steps') | |
| @click.option('--batch-size', type=int, default=64, | |
| help='Batch size for processing') | |
| @click.option('--batch-num', type=int, default=12, | |
| help='Number of batches to process') | |
| @click.option('--num-workers', type=int, default=0, | |
| help='Number of DataLoader workers') | |
| @click.option('--split', type=str, default='test', | |
| help='Data split to process') | |
| # Anomaly detection parameters | |
| @click.option('--anomaly-binary-threshold', type=int, default=5, | |
| help='Binary threshold for anomaly detection') | |
| @click.option('--anomaly-pixel-num-threshold', type=int, default=0, | |
| help='Pixel number threshold') | |
| @click.option('--adaptive-threshold', type=float, default=0.1, | |
| help='Adaptive threshold for contour-based masks') | |
| # Output parameters | |
| @click.option('--results-dir', type=click.Path(), | |
| help='Results directory (default: auto-generated)') | |
| @click.option('--tag', type=str, | |
| help='Custom tag for output directory') | |
| @click.option('--enable-excel-report', is_flag=True, default=False, | |
| help='Generate Excel report') | |
| @click.option('--enable-save-image-results', is_flag=True, default=False, | |
| help='Save image results') | |
| @click.option('--enable-save-whole-image-results', is_flag=True, default=False, | |
| help='Save whole image results') | |
| @click.option('--save-reconstructions', is_flag=True, default=None, | |
| help='Save reconstructed images (VAE and Deco-Diff) as PNG') | |
| @click.option('--save-reconstructions-npy', is_flag=True, default=None, | |
| help='Save raw reconstruction arrays as NPY files') | |
| @click.option('--save-raw-anomaly-maps', is_flag=True, default=None, | |
| help='Save raw anomaly map arrays as NPY files (image-level)') | |
| @click.option('--save-image-variants', type=str, multiple=True, | |
| help='Which image variants to save (can specify multiple: continuous, binary, grayscale, absolute, all)') | |
| @click.option('--save-colormaps', type=str, multiple=True, | |
| help='Colormaps to use for continuous maps (can specify multiple: jet, hot, viridis, plasma, magma, inferno, turbo)') | |
| @click.option('--save-normalization', type=click.Choice(['minmax', 'zscore', 'percentile', 'raw']), | |
| help='Normalization method for visualization') | |
| @click.option('--normal-center', type=float, | |
| help='Center value for grayscale normalization (default: 125.0)') | |
| @click.option('--save-binary-thresholds', type=float, multiple=True, | |
| help='Threshold values for binary maps (can specify multiple, e.g., 5.0 10.0 15.0)') | |
| @click.option('--enable-confusion-matrix', is_flag=True, default=False, | |
| help='Create confusion matrix') | |
| # Filename generation strategy | |
| @click.option('--filename-strategy', type=click.Choice(['auto', 'simple', 'anomaly_type', 'full_path', | |
| 'incremental', 'dataset_index']), | |
| default='auto', | |
| help='Filename generation strategy') | |
| @click.option('--regression-test-mode', is_flag=True, default=False, | |
| help='Enable regression test mode') | |
| # Bootstrap parameters | |
| @click.option('--bootstrap-samples', type=int, | |
| help='Enable bootstrap sampling: save only N randomly sampled patches') | |
| @click.option('--bootstrap-seed', type=int, default=42, | |
| help='Random seed for bootstrap sampling') | |
| @click.option('--bootstrap-sampling-trial-number', type=int, default=1, | |
| help='Number of bootstrap resampling trials') | |
| # Utility | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show commands without executing') | |
| def eval_process(config, mode, pretrained, annotation_dir, dataset, object_class, anomaly_class, | |
| data_dir, csv_split_file, model_size, vae_type, image_size, center_size, | |
| center_crop, patch_size, stride, irregular_patch, reverse_steps, batch_size, | |
| batch_num, num_workers, split, anomaly_binary_threshold, | |
| anomaly_pixel_num_threshold, adaptive_threshold, results_dir, tag, | |
| enable_excel_report, enable_save_image_results, enable_save_whole_image_results, | |
| save_reconstructions, save_reconstructions_npy, save_raw_anomaly_maps, | |
| save_image_variants, save_colormaps, save_normalization, normal_center, | |
| save_binary_thresholds, enable_confusion_matrix, | |
| filename_strategy, regression_test_mode, bootstrap_samples, bootstrap_seed, | |
| bootstrap_sampling_trial_number, dry_run): | |
| """Run evaluation processing using DecodiffEvaluateProcessor. | |
| Provides multiple execution modes for evaluation and processing: | |
| - save_only: Save .npy files and diff images only | |
| - process_only: Process existing .npy files for categorization | |
| - save_and_process: Save and immediately process | |
| - full_pipeline: Complete pipeline without saving intermediates | |
| - full_pipeline_with_saving_npy: Complete pipeline with saving NPY files | |
| - save_and_process_with_bootstrap: Bootstrap evaluation with two-stage sampling | |
| Examples: | |
| # Full pipeline mode | |
| hae eval-process --mode full_pipeline --pretrained models/bottle.pth --annotation-dir ./annotations | |
| # Process only existing NPY files | |
| hae eval-process --mode process_only --annotation-dir ./annotations | |
| # With config file | |
| hae eval-process --config config/process.json --mode full_pipeline | |
| """ | |
| # 1. Load config if provided | |
| config_dict = {} | |
| if config: | |
| from dioodmi.cli.config import load_config_with_env, detect_environment | |
| from pathlib import Path | |
| config_dict = load_config_with_env(Path(config)) | |
| console.print(f"[cyan]📄 Loaded config:[/cyan] {config}") | |
| console.print(f" Environment: {detect_environment()}") | |
| # 2. Build parameter dictionary from CLI arguments | |
| cli_params = { | |
| 'mode': mode, | |
| 'pretrained': pretrained, | |
| 'annotation_dir': annotation_dir, | |
| 'dataset': dataset, | |
| 'object_class': object_class, | |
| 'anomaly_class': anomaly_class, | |
| 'data_dir': data_dir, | |
| 'csv_split_file': csv_split_file, | |
| 'model_size': model_size, | |
| 'vae_type': vae_type, | |
| 'image_size': image_size, | |
| 'center_size': center_size, | |
| 'center_crop': center_crop, | |
| 'patch_size': patch_size, | |
| 'stride': stride, | |
| 'irregular_patch': irregular_patch, | |
| 'reverse_steps': reverse_steps, | |
| 'batch_size': batch_size, | |
| 'batch_num': batch_num, | |
| 'num_workers': num_workers, | |
| 'split': split, | |
| 'anomaly_binary_threshold': anomaly_binary_threshold, | |
| 'anomaly_pixel_num_threshold': anomaly_pixel_num_threshold, | |
| 'adaptive_threshold': adaptive_threshold, | |
| 'results_dir': results_dir, | |
| 'tag': tag, | |
| 'enable_excel_report': enable_excel_report, | |
| 'enable_save_image_results': enable_save_image_results, | |
| 'enable_save_whole_image_results': enable_save_whole_image_results, | |
| 'save_reconstructions': save_reconstructions, | |
| 'save_reconstructions_npy': save_reconstructions_npy, | |
| 'save_raw_anomaly_maps': save_raw_anomaly_maps, | |
| 'save_image_variants': list(save_image_variants) if save_image_variants else None, | |
| 'save_colormaps': list(save_colormaps) if save_colormaps else None, | |
| 'save_normalization': save_normalization, | |
| 'normal_center': normal_center, | |
| 'save_binary_thresholds': list(save_binary_thresholds) if save_binary_thresholds else None, | |
| 'enable_confusion_matrix': enable_confusion_matrix, | |
| 'filename_strategy': filename_strategy, | |
| 'regression_test_mode': regression_test_mode, | |
| 'bootstrap_samples': bootstrap_samples, | |
| 'bootstrap_seed': bootstrap_seed, | |
| 'bootstrap_sampling_trial_number': bootstrap_sampling_trial_number, | |
| } | |
| # 3. Merge: config values first, then override with CLI args | |
| final_params = config_dict.copy() | |
| for key, value in cli_params.items(): | |
| if value is not None: # CLI arg was explicitly provided | |
| final_params[key] = value | |
| # 4. Validate required parameters | |
| if 'mode' not in final_params or final_params['mode'] is None: | |
| console.print("[red]❌ Error: --mode required (or specify in config)[/red]") | |
| sys.exit(1) | |
| if 'annotation_dir' not in final_params or final_params['annotation_dir'] is None: | |
| console.print("[red]❌ Error: --annotation-dir required (or specify in config)[/red]") | |
| sys.exit(1) | |
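| # Note: --pretrained is required for the save_* modes (see its help text) but | |
| # is not validated here; dioodmi-process is presumably the one that rejects a | |
| # missing checkpoint for those modes. | |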
| # Display evaluation info | |
| eval_mode = final_params.get('mode', 'unknown') | |
| eval_dataset = final_params.get('dataset', 'mvtec') | |
| eval_object = final_params.get('object_class', 'all') | |
| console.print(f"[bold cyan]⚙️ Evaluation Processing (DecodiffEvaluateProcessor)[/bold cyan]") | |
| console.print(f"[bold]Mode:[/bold] {eval_mode}") | |
| console.print(f"[bold]Dataset:[/bold] {eval_dataset}" + | |
| (f"/{eval_object}" if eval_object != 'all' else "")) | |
| console.print(f"[bold]Model:[/bold] {final_params.get('model_size', 'UNet_L')}\n") | |
| # 5. Build platform-specific command | |
| if sys.platform.startswith('win'): | |
| cmd = ["py", "-3.11", "-m", "dioodmi.cli.process"] | |
| else: | |
| cmd = ["dioodmi-process"] | |
| # 6. Add all parameters to command | |
| cmd.extend(["--mode", final_params['mode']]) | |
| if 'pretrained' in final_params and final_params['pretrained']: | |
| cmd.extend(["--pretrained", str(final_params['pretrained'])]) | |
| if 'annotation_dir' in final_params: | |
| cmd.extend(["--annotation_dir", final_params['annotation_dir']]) | |
| if 'dataset' in final_params: | |
| cmd.extend(["--dataset", final_params['dataset']]) | |
| if 'object_class' in final_params: | |
| cmd.extend(["--object_class", final_params['object_class']]) | |
| if 'anomaly_class' in final_params: | |
| cmd.extend(["--anomaly_class", final_params['anomaly_class']]) | |
| if 'data_dir' in final_params: | |
| cmd.extend(["--data_dir", final_params['data_dir']]) | |
| if 'csv_split_file' in final_params: | |
| cmd.extend(["--csv_split_file", final_params['csv_split_file']]) | |
| # Model configuration | |
| if 'model_size' in final_params: | |
| cmd.extend(["--model_size", final_params['model_size']]) | |
| if 'vae_type' in final_params: | |
| cmd.extend(["--vae_type", final_params['vae_type']]) | |
| if 'image_size' in final_params: | |
| cmd.extend(["--image_size", str(final_params['image_size'])]) | |
| if 'center_size' in final_params: | |
| cmd.extend(["--center_size", str(final_params['center_size'])]) | |
| if 'center_crop' in final_params: | |
| cmd.extend(["--center_crop", final_params['center_crop']]) | |
| # Processing parameters | |
| if 'patch_size' in final_params: | |
| cmd.extend(["--patch_size", str(final_params['patch_size'])]) | |
| if 'stride' in final_params and final_params['stride'] is not None: | |
| # Handle the literal strings "null"/"None" that config substitution can emit | |
| # (a null value serialized through str() arrives here as the string "None") | |
| stride_val = final_params['stride'] | |
| stride_str = str(stride_val).strip().lower() | |
| if stride_str not in ('null', 'none', ''): | |
| cmd.extend(["--stride", str(stride_val)]) | |
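| # e.g. stride=64              -> cmd += ['--stride', '64'] | |
| #      stride="None" / "null" -> flag omitted (downstream default: no overlap) | |
| #      stride=None            -> already caught by the `is not None` guard | |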
| if 'irregular_patch' in final_params and final_params['irregular_patch']: | |
| cmd.append("--irregular_patch") | |
| if 'reverse_steps' in final_params: | |
| cmd.extend(["--reverse_steps", str(final_params['reverse_steps'])]) | |
| if 'batch_size' in final_params: | |
| cmd.extend(["--batch_size", str(final_params['batch_size'])]) | |
| if 'batch_num' in final_params: | |
| cmd.extend(["--batch_num", str(final_params['batch_num'])]) | |
| if 'num_workers' in final_params: | |
| cmd.extend(["--num_workers", str(final_params['num_workers'])]) | |
| if 'split' in final_params: | |
| cmd.extend(["--split", final_params['split']]) | |
| # Anomaly detection parameters | |
| if 'anomaly_binary_threshold' in final_params: | |
| cmd.extend(["--anomaly_binary_threshold", str(final_params['anomaly_binary_threshold'])]) | |
| if 'anomaly_pixel_num_threshold' in final_params: | |
| cmd.extend(["--anomaly_pixel_num_threshold", str(final_params['anomaly_pixel_num_threshold'])]) | |
| if 'adaptive_threshold' in final_params: | |
| cmd.extend(["--adaptive_threshold", str(final_params['adaptive_threshold'])]) | |
| # Output parameters | |
| if 'results_dir' in final_params: | |
| cmd.extend(["--results_dir", final_params['results_dir']]) | |
| if 'tag' in final_params: | |
| cmd.extend(["--tag", final_params['tag']]) | |
| if 'enable_excel_report' in final_params and final_params['enable_excel_report']: | |
| cmd.append("--enable_excel_report") | |
| if 'enable_save_image_results' in final_params and final_params['enable_save_image_results']: | |
| cmd.append("--enable_save_image_results") | |
| if 'enable_save_whole_image_results' in final_params and final_params['enable_save_whole_image_results']: | |
| cmd.append("--enable_save_whole_image_results") | |
| if 'save_reconstructions' in final_params and final_params['save_reconstructions']: | |
| cmd.append("--save_reconstructions") | |
| if 'save_reconstructions_npy' in final_params and final_params['save_reconstructions_npy']: | |
| cmd.append("--save_reconstructions_npy") | |
| if 'save_raw_anomaly_maps' in final_params and final_params['save_raw_anomaly_maps']: | |
| cmd.append("--save_raw_anomaly_maps") | |
| if 'save_image_variants' in final_params and final_params['save_image_variants']: | |
| cmd.extend(["--save-image-variants"] + [str(v) for v in final_params['save_image_variants']]) | |
| if 'save_colormaps' in final_params and final_params['save_colormaps']: | |
| cmd.extend(["--save-colormaps"] + [str(v) for v in final_params['save_colormaps']]) | |
| if 'save_normalization' in final_params and final_params['save_normalization']: | |
| cmd.extend(["--save-normalization", str(final_params['save_normalization'])]) | |
| if 'normal_center' in final_params and final_params['normal_center'] is not None: | |
| cmd.extend(["--normal-center", str(final_params['normal_center'])]) | |
| if 'save_binary_thresholds' in final_params and final_params['save_binary_thresholds']: | |
| cmd.extend(["--save-binary-thresholds"] + [str(v) for v in final_params['save_binary_thresholds']]) | |
| if 'enable_confusion_matrix' in final_params and final_params['enable_confusion_matrix']: | |
| cmd.append("--enable_confusion_matrix") | |
| # Filename strategy | |
| if 'filename_strategy' in final_params: | |
| cmd.extend(["--filename_strategy", final_params['filename_strategy']]) | |
| if 'regression_test_mode' in final_params and final_params['regression_test_mode']: | |
| cmd.append("--regression_test_mode") | |
| # Bootstrap parameters | |
| if 'bootstrap_samples' in final_params and final_params['bootstrap_samples'] is not None: | |
| cmd.extend(["--bootstrap_samples", str(final_params['bootstrap_samples'])]) | |
| if 'bootstrap_seed' in final_params: | |
| cmd.extend(["--bootstrap_seed", str(final_params['bootstrap_seed'])]) | |
| if 'bootstrap_sampling_trial_number' in final_params: | |
| cmd.extend(["--bootstrap_sampling_trial_number", str(final_params['bootstrap_sampling_trial_number'])]) | |
| # 7. Show standalone alternative | |
| show_standalone_alternatives(cmd, f"Evaluation Processing ({eval_mode}) on {eval_dataset}") | |
| # 8. Execute | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
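| # For reference, the forwarded command assembled above looks roughly like | |
| # (illustrative values; `py -3.11 -m dioodmi.cli.process ...` on Windows): | |
| #   dioodmi-process --mode full_pipeline --pretrained models/bottle.pth \ | |
| #     --annotation_dir ./annotations --dataset mvtec --patch_size 128 | |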
| @main.command('reconstruct') | |
| @click.option('--output-dir', type=click.Path(), required=True, | |
| help='Output directory for results') | |
| @click.option('--model-name', help='Model name') | |
| @click.option('--vqvae-checkpoint', type=click.Path(exists=True), | |
| help='Path to VQ-VAE checkpoint for LDM') | |
| @click.option('--num-inference-steps', type=int, default=100, | |
| help='Number of diffusion inference steps') | |
| @click.option('--batch-size', type=int, default=256, | |
| help='Batch size for reconstruction') | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show commands without executing') | |
| def reconstruct(output_dir, model_name, vqvae_checkpoint, num_inference_steps, | |
| batch_size, dry_run): | |
| """Reconstruct images using diffusion model. | |
| Orchestrates dioodmi-reconstruct with appropriate configuration. | |
| Examples: | |
| hae reconstruct --output-dir results/ --model-name my_model | |
| hae reconstruct --output-dir results/ --vqvae-checkpoint models/vqvae.pth | |
| """ | |
| console.print(f"[bold cyan]🔄 Reconstructing images[/bold cyan]\n") | |
| # Build command | |
| cmd = ["dioodmi-reconstruct", "--output_dir", output_dir] | |
| if model_name: | |
| cmd.extend(["--model_name", model_name]) | |
| if vqvae_checkpoint: | |
| cmd.extend(["--vqvae_checkpoint", vqvae_checkpoint]) | |
| if num_inference_steps: | |
| cmd.extend(["--num_inference_steps", str(num_inference_steps)]) | |
| if batch_size: | |
| cmd.extend(["--batch_size", str(batch_size)]) | |
| # Show standalone alternatives | |
| show_standalone_alternatives(cmd, "Reconstruct images with diffusion") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| # Execute command | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| class BenchmarkPreprocessingCommand(click.Command): | |
| """Command that preprocesses --latent_sizes arguments to handle space-separated values.""" | |
| def make_context(self, info_name, args, parent=None, **extra): | |
| args = preprocess_latent_sizes_args(list(args)) | |
| return super().make_context(info_name, args, parent=parent, **extra) | |
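| # preprocess_latent_sizes_args (defined elsewhere) is assumed to rewrite the | |
| # space-separated form into the repeated form Click's multiple=True expects: | |
| #   ['--latent_sizes', '16', '32'] -> ['--latent_sizes', '16', '--latent_sizes', '32'] | |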
| @main.command('benchmark', cls=BenchmarkPreprocessingCommand) | |
| @click.option('--dataset', type=click.Choice(['mvtec', 'visa']), required=False, | |
| help='Dataset to benchmark (for evaluation benchmarking)') | |
| @click.option('--models', multiple=True, | |
| help='Models to benchmark (can specify multiple, for evaluation benchmarking)') | |
| @click.option('--categories', multiple=True, | |
| help='Categories to benchmark (default: all, for evaluation benchmarking)') | |
| @click.option('--metrics', multiple=True, default=['auroc'], | |
| help='Metrics to measure (for evaluation benchmarking)') | |
| @click.option('--output', type=click.Path(), default='benchmarks/', | |
| help='Benchmark results output directory') | |
| # Inference speed benchmarking options | |
| @click.option('--model', type=str, | |
| help='Model to benchmark for inference speed (UNet_L, TSFormer_L, WaveMamba_L, UNetFaster_L, compare)') | |
| @click.option('--latent_sizes', multiple=True, type=int, | |
| help='Latent sizes to test (can specify space-separated, e.g., --latent_sizes 16 32 64)') | |
| @click.option('--ncls', type=int, default=15, | |
| help='Number of classes (default: 15)') | |
| @click.option('--batch_size', type=int, default=4, | |
| help='Batch size (default: 4)') | |
| @click.option('--num_warmup', type=int, default=10, | |
| help='Number of warmup iterations (default: 10)') | |
| @click.option('--num_iterations', type=int, default=100, | |
| help='Number of benchmark iterations (default: 100)') | |
| @click.option('--output_dir', type=click.Path(), | |
| help='Output directory for inference speed benchmark results') | |
| def benchmark(dataset, models, categories, metrics, output, model, latent_sizes, ncls, batch_size, num_warmup, num_iterations, output_dir): | |
| """Run performance benchmarks on multiple models. | |
| Supports two modes: | |
| 1. Evaluation benchmarking: Compare models on datasets (--dataset required) | |
| 2. Inference speed benchmarking: Measure inference speed across different sizes (--model required) | |
| Examples: | |
| # Evaluation benchmarking | |
| hae benchmark --dataset mvtec --models ddpm --models decodiff | |
| hae benchmark --dataset mvtec --categories bottle --categories cable | |
| # Inference speed benchmarking | |
| hae benchmark --model UNet_L --latent_sizes 16 32 64 128 --ncls 15 --batch_size 4 --output_dir ./results | |
| hae benchmark --model compare --output_dir ./results | |
| """ | |
| # Determine which mode to use | |
| if model is not None: | |
| # Inference speed benchmarking mode | |
| _run_inference_speed_benchmark(model, latent_sizes, ncls, batch_size, num_warmup, num_iterations, output_dir or output) | |
| elif dataset is not None: | |
| # Evaluation benchmarking mode | |
| console.print(f"[bold yellow]⚡ Benchmarking on {dataset}[/bold yellow]\n") | |
| if models: | |
| console.print(f"Models: {', '.join(models)}") | |
| if categories: | |
| console.print(f"Categories: {', '.join(categories)}") | |
| console.print(f"Metrics: {', '.join(metrics)}") | |
| # TODO: Implement benchmarking workflow | |
| console.print("\n[red]Evaluation benchmarking workflow not yet implemented in HAE[/red]") | |
| console.print("This will orchestrate multiple dioodmi-eval runs and aggregate results") | |
| else: | |
| console.print("[red]Error: Either --dataset (for evaluation benchmarking) or --model (for inference speed benchmarking) must be specified[/red]") | |
| console.print("\nUse --help for usage information") | |
| sys.exit(1) | |
| def _run_inference_speed_benchmark(model, latent_sizes, ncls, batch_size, num_warmup, num_iterations, output_dir): | |
| """Run inference speed benchmarking.""" | |
| from pathlib import Path | |
| console.print(f"[bold yellow]⚡ Inference Speed Benchmarking[/bold yellow]\n") | |
| console.print(f"Model: {model}") | |
| if latent_sizes: | |
| console.print(f"Latent sizes: {', '.join(map(str, latent_sizes))}") | |
| console.print(f"Classes: {ncls}") | |
| console.print(f"Batch size: {batch_size}") | |
| console.print(f"Warmup iterations: {num_warmup}") | |
| console.print(f"Benchmark iterations: {num_iterations}") | |
| console.print(f"Output directory: {output_dir}\n") | |
| # Find the benchmark script | |
| # Try multiple paths: relative to current dir, then relative to hae-py location | |
| script_name = "benchmark_inference_speed_multi_size.py" | |
| script_path = None | |
| # Try relative to current working directory | |
| cwd_script = Path.cwd() / "experiments" / "definitions" / "scripts" / script_name | |
| if cwd_script.exists(): | |
| script_path = cwd_script | |
| # Try relative to hae-py source location (go up from hae/cli.py to project root) | |
| if script_path is None: | |
| hae_py_root = Path(__file__).parent.parent.parent.parent | |
| hae_script = hae_py_root / "experiments" / "definitions" / "scripts" / script_name | |
| if hae_script.exists(): | |
| script_path = hae_script | |
| # Finally, try definitions/scripts relative to the current directory, resolved to an absolute path | |
| if script_path is None: | |
| abs_script = Path("definitions/scripts") / script_name | |
| if abs_script.exists(): | |
| script_path = abs_script.resolve() | |
| if script_path is None or not script_path.exists(): | |
| console.print(f"[red]Error: Benchmark script not found[/red]") | |
| console.print(f"Looked for: {script_name}") | |
| console.print(f" - {cwd_script}") | |
| console.print(f" - {hae_script if 'hae_script' in locals() else 'N/A'}") | |
| console.print("\nMake sure you're running from the project root directory") | |
| sys.exit(1) | |
| # Build command arguments (without Python interpreter for display) | |
| cmd_args = [str(script_path), "--model", model] | |
| if latent_sizes: | |
| # Convert tuple to list of strings for the script (which uses nargs='+') | |
| cmd_args.extend(["--latent_sizes"] + [str(s) for s in latent_sizes]) | |
| cmd_args.extend([ | |
| "--ncls", str(ncls), | |
| "--batch_size", str(batch_size), | |
| "--num_warmup", str(num_warmup), | |
| "--num_iterations", str(num_iterations), | |
| "--output_dir", str(output_dir) | |
| ]) | |
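| # Illustrative cmd_args for `hae benchmark --model UNet_L --latent_sizes 16 32` | |
| # with all defaults: ['.../benchmark_inference_speed_multi_size.py', '--model', | |
| # 'UNet_L', '--latent_sizes', '16', '32', '--ncls', '15', '--batch_size', '4', | |
| # '--num_warmup', '10', '--num_iterations', '100', '--output_dir', 'benchmarks/'] | |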
| # Show standalone alternatives (display only, for user reference) | |
| console.print("[bold]💡 Standalone alternatives:[/bold]") | |
| console.print(f" Direct: {' '.join(cmd_args)}") | |
| console.print(f" Python: python3 {' '.join(cmd_args)}") | |
| console.print() | |
| # Execute with sys.executable for cross-platform compatibility | |
| # This works on both Windows (where .py files can't be executed directly) | |
| # and NixOS/Unix (ensures using the same Python environment) | |
| exec_cmd = [sys.executable] + cmd_args | |
| console.print(f"[bold]Executing benchmark...[/bold]\n") | |
| try: | |
| result = subprocess.run(exec_cmd, check=True) | |
| console.print(f"\n[green]✓ Benchmark completed successfully[/green]") | |
| console.print(f"Results saved to: {output_dir}") | |
| except subprocess.CalledProcessError as e: | |
| console.print(f"\n[red]✗ Benchmark failed with exit code {e.returncode}[/red]") | |
| sys.exit(e.returncode) | |
| except FileNotFoundError: | |
| console.print(f"\n[red]✗ Benchmark script not found: {script_path}[/red]") | |
| console.print("Make sure you're running from the project root directory") | |
| sys.exit(1) | |
| @main.group() | |
| def examples(): | |
| """Manage and run example configurations. | |
| Examples provide pre-configured workflows for common use cases. | |
| """ | |
| pass | |
| @examples.command('list') | |
| def examples_list(): | |
| """List available examples.""" | |
| console.print("[bold]📚 Available Examples[/bold]\n") | |
| # TODO: Load from examples/ directory | |
| table = Table(title="Example Configurations") | |
| table.add_column("Name", style="cyan") | |
| table.add_column("Description", style="green") | |
| table.add_column("Dataset", style="yellow") | |
| # Placeholder examples | |
| table.add_row("mvtec-bottle", "Train DeCo-Diff on MVTec bottle", "MVTec-AD") | |
| table.add_row("visa-candle", "Train DeCo-Diff on VisA candle", "VisA") | |
| table.add_row("quick-test", "Quick test run (1 epoch)", "MVTec-AD") | |
| console.print(table) | |
| console.print("\n[dim]Use 'hae examples run <name>' to execute an example[/dim]") | |
| @examples.command('info') | |
| @click.argument('name') | |
| def examples_info(name): | |
| """Show details about an example.""" | |
| console.print(f"[bold]📖 Example: {name}[/bold]\n") | |
| # TODO: Load from examples/ directory | |
| console.print("[red]Example info not yet implemented[/red]") | |
| console.print(f"Will show details about example '{name}'") | |
| @examples.command('run') | |
| @click.argument('name') | |
| @click.option('--train/--no-train', default=False, help='Run training') | |
| @click.option('--eval/--no-eval', default=False, help='Run evaluation') | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show commands without executing') | |
| def examples_run(name, train, eval, dry_run): | |
| """Run an example configuration. | |
| Examples: | |
| hae examples run mvtec-bottle --train | |
| hae examples run mvtec-bottle --eval | |
| hae examples run quick-test --train --eval | |
| """ | |
| console.print(f"[bold]🎯 Running Example: {name}[/bold]\n") | |
| operations = [] | |
| if train: | |
| operations.append("Train") | |
| if eval: | |
| operations.append("Eval") | |
| if not operations: | |
| console.print("[red]Error: Specify at least one operation (--train or --eval)[/red]") | |
| sys.exit(1) | |
| # TODO: Load example config and execute | |
| console.print("[red]Example execution not yet implemented[/red]") | |
| console.print(f"Will execute: {', '.join(operations)}") | |
| @examples.command('export') | |
| @click.argument('name') | |
| @click.option('--output', type=click.Path(), default='.', | |
| help='Directory to export configs to') | |
| def examples_export(name, output): | |
| """Export example configs to a directory. | |
| Examples: | |
| hae examples export mvtec-bottle | |
| hae examples export mvtec-bottle --output my-experiment/ | |
| """ | |
| console.print(f"[bold]📤 Exporting Example: {name}[/bold]\n") | |
| # TODO: Copy example configs | |
| console.print("[red]Example export not yet implemented[/red]") | |
| console.print(f"Will export configs to: {output}") | |
| @main.group() | |
| def dev(): | |
| """Development tools and utilities. | |
| Provides interactive tools, code quality checks, and utilities for DiOodMi development. | |
| """ | |
| pass | |
| @dev.command('shell') | |
| def dev_shell(): | |
| """Start interactive Python shell with DiOodMi loaded. | |
| Example: | |
| hae dev shell | |
| """ | |
| console.print("[bold cyan]🐍 Starting Python shell with DiOodMi[/bold cyan]\n") | |
| # Build command for IPython with dioodmi preloaded | |
| cmd = ["python", "-m", "IPython", "-i", "-c", | |
| "import dioodmi; from dioodmi.trainers import *; from dioodmi.networks import *; print('DiOodMi loaded')"] | |
| show_standalone_alternatives(["python", "-m", "IPython"], "Interactive Python shell") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dev.command('jupyter') | |
| @click.option('--port', type=int, default=8888, help='Port for Jupyter server') | |
| @click.option('--no-browser', is_flag=True, help='Don\'t open browser automatically') | |
| def dev_jupyter(port, no_browser): | |
| """Launch Jupyter Lab for interactive development. | |
| Example: | |
| hae dev jupyter | |
| hae dev jupyter --port 9000 --no-browser | |
| """ | |
| console.print(f"[bold cyan]📓 Launching Jupyter Lab on port {port}[/bold cyan]\n") | |
| cmd = ["jupyter", "lab", f"--port={port}"] | |
| if no_browser: | |
| cmd.append("--no-browser") | |
| show_standalone_alternatives(cmd[:2], f"Jupyter Lab on port {port}") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dev.command('format') | |
| @click.option('--check', is_flag=True, help='Check only, don\'t modify files') | |
| @click.option('--path', type=click.Path(exists=True), default='.', | |
| help='Path to format') | |
| def dev_format(check, path): | |
| """Format code with ruff. | |
| Example: | |
| hae dev format | |
| hae dev format --check | |
| """ | |
| console.print("[bold green]✨ Formatting code with ruff[/bold green]\n") | |
| cmd = ["ruff", "format"] | |
| if check: | |
| cmd.append("--check") | |
| cmd.append(path) | |
| show_standalone_alternatives(cmd, "Format code with ruff") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print("\n[green]✓ Formatting complete[/green]") | |
| sys.exit(result.returncode) | |
| @dev.command('lint') | |
| @click.option('--fix', is_flag=True, help='Automatically fix issues') | |
| @click.option('--path', type=click.Path(exists=True), default='.', | |
| help='Path to lint') | |
| def dev_lint(fix, path): | |
| """Lint code with ruff. | |
| Example: | |
| hae dev lint | |
| hae dev lint --fix | |
| """ | |
| console.print("[bold yellow]🔍 Linting code with ruff[/bold yellow]\n") | |
| cmd = ["ruff", "check"] | |
| if fix: | |
| cmd.append("--fix") | |
| cmd.append(path) | |
| show_standalone_alternatives(cmd, "Lint code with ruff") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print("\n[green]✓ No linting issues found[/green]") | |
| sys.exit(result.returncode) | |
| @dev.command('test') | |
| @click.option('--path', type=click.Path(exists=True), default='tests/', | |
| help='Test path or file') | |
| @click.option('--coverage', is_flag=True, help='Run with coverage report') | |
| @click.option('--verbose', '-v', is_flag=True, help='Verbose output') | |
| @click.option('--keyword', '-k', help='Run tests matching keyword') | |
| def dev_test(path, coverage, verbose, keyword): | |
| """Run tests with pytest. | |
| Example: | |
| hae dev test | |
| hae dev test --coverage | |
| hae dev test -k "test_decodiff" | |
| """ | |
| console.print("[bold blue]🧪 Running tests with pytest[/bold blue]\n") | |
| cmd = ["pytest"] | |
| if coverage: | |
| cmd.extend(["--cov=dioodmi", "--cov-report=term-missing"]) | |
| if verbose: | |
| cmd.append("-v") | |
| if keyword: | |
| cmd.extend(["-k", keyword]) | |
| cmd.append(path) | |
| show_standalone_alternatives(cmd[:1], "Run pytest tests") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dev.command('typecheck') | |
| @click.option('--path', type=click.Path(exists=True), default='.', | |
| help='Path to type check') | |
| def dev_typecheck(path): | |
| """Run mypy type checking. | |
| Example: | |
| hae dev typecheck | |
| """ | |
| console.print("[bold magenta]🔎 Type checking with mypy[/bold magenta]\n") | |
| cmd = ["mypy", path] | |
| show_standalone_alternatives(cmd, "Type check with mypy") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print("\n[green]✓ Type checking passed[/green]") | |
| sys.exit(result.returncode) | |
| @dev.command('clean') | |
| @click.option('--deep', is_flag=True, help='Also remove .venv, build artifacts') | |
| def dev_clean(deep): | |
| """Clean build artifacts and caches. | |
| Example: | |
| hae dev clean | |
| hae dev clean --deep | |
| """ | |
| console.print("[bold red]🧹 Cleaning build artifacts[/bold red]\n") | |
| import shutil | |
| from pathlib import Path | |
| patterns = [ | |
| "**/__pycache__", | |
| "**/*.pyc", | |
| "**/*.pyo", | |
| "**/.pytest_cache", | |
| "**/.ruff_cache", | |
| "**/.mypy_cache", | |
| "**/htmlcov", | |
| "**/.coverage", | |
| ] | |
| if deep: | |
| patterns.extend([ | |
| "**/.venv", | |
| "**/build", | |
| "**/dist", | |
| "**/*.egg-info", | |
| ]) | |
| removed_count = 0 | |
| for pattern in patterns: | |
| for path in Path(".").glob(pattern): | |
| if path.exists(): | |
| console.print(f" Removing: {path}") | |
| if path.is_dir(): | |
| shutil.rmtree(path) | |
| else: | |
| path.unlink() | |
| removed_count += 1 | |
| console.print(f"\n[green]✓ Cleaned {removed_count} items[/green]") | |
| @dev.command('docs') | |
| @click.option('--serve', is_flag=True, help='Serve docs locally') | |
| @click.option('--port', type=int, default=8000, help='Port for doc server') | |
| def dev_docs(serve, port): | |
| """Generate documentation. | |
| Example: | |
| hae dev docs | |
| hae dev docs --serve | |
| """ | |
| console.print("[bold cyan]📚 Generating documentation[/bold cyan]\n") | |
| # TODO: Implement documentation generation | |
| console.print("[yellow]Documentation generation not yet implemented[/yellow]") | |
| console.print("Will use Sphinx or mkdocs to generate from docstrings") | |
| @dev.command('check-deps') | |
| def dev_check_deps(): | |
| """Check dependencies for updates and security issues. | |
| Example: | |
| hae dev check-deps | |
| """ | |
| console.print("[bold yellow]📦 Checking dependencies[/bold yellow]\n") | |
| # Check for outdated packages | |
| console.print("[cyan]Checking for updates...[/cyan]") | |
| cmd = ["uv", "pip", "list", "--outdated"] | |
| show_standalone_alternatives(cmd, "Check outdated packages") | |
| subprocess.run(cmd) | |
| console.print("\n[cyan]Security audit coming soon...[/cyan]") | |
| @dev.command('profile') | |
| @click.argument('script', type=click.Path(exists=True)) | |
| @click.option('--sort', default='cumulative', help='Sort order (cumulative, time, calls)') | |
| def dev_profile(script, sort): | |
| """Profile a Python script for performance analysis. | |
| Example: | |
| hae dev profile project/train_decodiff.py | |
| """ | |
| console.print(f"[bold magenta]⚡ Profiling {script}[/bold magenta]\n") | |
| cmd = ["python", "-m", "cProfile", "-s", sort, script] | |
| show_standalone_alternatives(cmd, f"Profile {script}") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dev.command('check-cuda') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def dev_check_cuda(dry_run): | |
| """Check CUDA availability and GPU information. | |
| Displays PyTorch CUDA version, cuDNN version, number of GPUs, | |
| and detailed information about each GPU (name, compute capability, memory). | |
| Example: | |
| hae dev check-cuda | |
| """ | |
| console.print("[bold green]🎮 Checking CUDA Setup[/bold green]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "check-cuda.py" | |
| cmd = ["python", str(script_path)] | |
| show_standalone_alternatives(cmd, "Check CUDA availability and GPU info") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dev.command('test-cufft') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def dev_test_cufft(dry_run): | |
| """Test CUDA FFT functionality. | |
| Quick test to verify cuFFT (CUDA Fast Fourier Transform) is working | |
| correctly by running a simple FFT operation on GPU. | |
| Example: | |
| hae dev test-cufft | |
| """ | |
| console.print("[bold green]📊 Testing cuFFT[/bold green]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "cufftTest.py" | |
| cmd = ["python", str(script_path)] | |
| show_standalone_alternatives(cmd, "Test CUDA FFT functionality") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dev.command('sort-interactive') | |
| @click.option('--pred-jsonl', required=True, type=click.Path(exists=True), help='Prediction JSONL file') | |
| @click.option('--out-dir', required=True, type=click.Path(), help='Output directory for sorted images') | |
| @click.option('--zoom', type=float, default=1.0, help='Zoom factor for display') | |
| @click.option('--resume', is_flag=True, help='Resume from previous session') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def dev_sort_interactive(pred_jsonl, out_dir, zoom, resume, dry_run): | |
| """Interactive image sorter for anomaly detection results. | |
| Opens an interactive GUI for manually sorting and categorizing | |
| anomaly detection results (good, false positive, missed, etc.). | |
| Examples: | |
| hae dev sort-interactive --pred-jsonl predictions.jsonl --out-dir ./sorted | |
| hae dev sort-interactive --pred-jsonl pred.jsonl --out-dir ./sorted --zoom 1.5 --resume | |
| """ | |
| console.print("[bold blue]🖱️ Interactive Sorter[/bold blue]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "interactive_sorter.py" | |
| cmd = [ | |
| "python", str(script_path), | |
| "--pred_jsonl", pred_jsonl, | |
| "--out_dir", out_dir, | |
| "--zoom", str(zoom) | |
| ] | |
| if resume: | |
| cmd.append("--resume") | |
| show_standalone_alternatives(cmd, "Interactive image sorter") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dev.command('test-evolving-patch') | |
| @click.argument('args', nargs=-1, type=click.UNPROCESSED) | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def dev_test_evolving_patch(args, dry_run): | |
| """Test evolving patch position encoding. | |
| Tests the position encoding mechanism for patches that evolve | |
| during training (dropout, augmentation, etc.). | |
| Example: | |
| hae dev test-evolving-patch | |
| hae dev test-evolving-patch --help | |
| """ | |
| console.print("[bold magenta]🧪 Testing Evolving Patch[/bold magenta]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "scripts" / "test_evolving_patch.py" | |
| cmd = ["python", str(script_path)] + list(args) | |
| show_standalone_alternatives(cmd, "Test evolving patch position encoding") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @main.group() | |
| def dvc(): | |
| """DVC (Data Version Control) operations. | |
| Manage datasets, models, pipelines, and experiments with DVC. | |
| """ | |
| pass | |
| @dvc.command('status') | |
| @click.option('--cloud', is_flag=True, help='Show status of cloud storage') | |
| def dvc_status(cloud): | |
| """Show DVC status (tracked files, changes). | |
| Example: | |
| hae dvc status | |
| hae dvc status --cloud | |
| """ | |
| console.print("[bold cyan]📊 DVC Status[/bold cyan]\n") | |
| cmd = ["dvc", "status"] | |
| if cloud: | |
| cmd.append("--cloud") | |
| show_standalone_alternatives(cmd, "Check DVC status") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dvc.command('add') | |
| @click.argument('path', type=click.Path(exists=True)) | |
| @click.option('--recursive', '-r', is_flag=True, help='Add directory recursively') | |
| def dvc_add(path, recursive): | |
| """Track file or directory with DVC. | |
| Example: | |
| hae dvc add data/mvtec-dataset/ | |
| hae dvc add models/checkpoint.pth | |
| """ | |
| console.print(f"[bold green]➕ Adding {path} to DVC[/bold green]\n") | |
| cmd = ["dvc", "add", path] | |
| if recursive: | |
| cmd.append("--recursive") | |
| show_standalone_alternatives(cmd, f"Track {path} with DVC") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print(f"\n[green]✓ Added {path} to DVC[/green]") | |
| console.print(f"[yellow]Don't forget to:[/yellow]") | |
| console.print(f" git add {path}.dvc .gitignore") | |
| console.print(f" git commit -m 'Add {path} to DVC'") | |
| sys.exit(result.returncode) | |
| @dvc.command('push') | |
| @click.option('--remote', '-r', help='Push to specific remote') | |
| @click.option('--all-branches', is_flag=True, help='Push all branches') | |
| @click.option('--all-tags', is_flag=True, help='Push all tags') | |
| def dvc_push(remote, all_branches, all_tags): | |
| """Push tracked data to remote storage. | |
| Example: | |
| hae dvc push | |
| hae dvc push --remote myremote | |
| """ | |
| console.print("[bold blue]⬆️ Pushing to DVC remote[/bold blue]\n") | |
| cmd = ["dvc", "push"] | |
| if remote: | |
| cmd.extend(["--remote", remote]) | |
| if all_branches: | |
| cmd.append("--all-branches") | |
| if all_tags: | |
| cmd.append("--all-tags") | |
| show_standalone_alternatives(cmd, "Push data to DVC remote") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print("\n[green]✓ Push complete[/green]") | |
| sys.exit(result.returncode) | |
| @dvc.command('pull') | |
| @click.option('--remote', '-r', help='Pull from specific remote') | |
| @click.option('--all-branches', is_flag=True, help='Pull all branches') | |
| @click.option('--all-tags', is_flag=True, help='Pull all tags') | |
| def dvc_pull(remote, all_branches, all_tags): | |
| """Pull tracked data from remote storage. | |
| Example: | |
| hae dvc pull | |
| hae dvc pull --remote myremote | |
| """ | |
| console.print("[bold blue]⬇️ Pulling from DVC remote[/bold blue]\n") | |
| cmd = ["dvc", "pull"] | |
| if remote: | |
| cmd.extend(["--remote", remote]) | |
| if all_branches: | |
| cmd.append("--all-branches") | |
| if all_tags: | |
| cmd.append("--all-tags") | |
| show_standalone_alternatives(cmd, "Pull data from DVC remote") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print("\n[green]✓ Pull complete[/green]") | |
| sys.exit(result.returncode) | |
| @dvc.command('checkout') | |
| @click.option('--force', is_flag=True, help='Force checkout even with conflicts') | |
| def dvc_checkout(force): | |
| """Checkout data files tracked by DVC. | |
| Example: | |
| hae dvc checkout | |
| hae dvc checkout --force | |
| """ | |
| console.print("[bold cyan]🔄 Checking out DVC files[/bold cyan]\n") | |
| cmd = ["dvc", "checkout"] | |
| if force: | |
| cmd.append("--force") | |
| show_standalone_alternatives(cmd, "Checkout DVC tracked files") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print("\n[green]✓ Checkout complete[/green]") | |
| sys.exit(result.returncode) | |
| @dvc.command('repro') | |
| @click.option('--force', is_flag=True, help='Reproduce even if dependencies unchanged') | |
| @click.option('--downstream', is_flag=True, help='Reproduce downstream stages') | |
| @click.argument('target', required=False) | |
| def dvc_repro(force, downstream, target): | |
| """Reproduce DVC pipeline. | |
| Example: | |
| hae dvc repro | |
| hae dvc repro train | |
| hae dvc repro --force --downstream | |
| """ | |
| console.print("[bold magenta]🔁 Reproducing DVC pipeline[/bold magenta]\n") | |
| cmd = ["dvc", "repro"] | |
| if force: | |
| cmd.append("--force") | |
| if downstream: | |
| cmd.append("--downstream") | |
| if target: | |
| cmd.append(target) | |
| show_standalone_alternatives(cmd, "Reproduce DVC pipeline") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print("\n[green]✓ Pipeline reproduced[/green]") | |
| sys.exit(result.returncode) | |
| @dvc.command('run') | |
| @click.option('--name', '-n', required=True, help='Stage name') | |
| @click.option('--deps', '-d', multiple=True, help='Dependencies') | |
| @click.option('--outs', '-o', multiple=True, help='Outputs') | |
| @click.option('--metrics', '-m', multiple=True, help='Metrics files') | |
| @click.option('--params', '-p', multiple=True, help='Parameters') | |
| @click.argument('command', nargs=-1, required=True) | |
| def dvc_run(name, deps, outs, metrics, params, command): | |
| """Create a DVC pipeline stage. | |
| Example: | |
| hae dvc run -n train -d data/ -o models/ -m metrics.json -- python train.py | |
| """ | |
| console.print(f"[bold green]▶️ Creating stage '{name}'[/bold green]\n") | |
| cmd = ["dvc", "run", "-n", name] | |
| for dep in deps: | |
| cmd.extend(["-d", dep]) | |
| for out in outs: | |
| cmd.extend(["-o", out]) | |
| for metric in metrics: | |
| cmd.extend(["-m", metric]) | |
| for param in params: | |
| cmd.extend(["-p", param]) | |
| cmd.extend(list(command)) | |
| show_standalone_alternatives(cmd, f"Create DVC stage '{name}'") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print(f"\n[green]✓ Stage '{name}' created[/green]") | |
| sys.exit(result.returncode) | |
| @dvc.group('exp') | |
| def dvc_exp(): | |
| """DVC experiment tracking commands.""" | |
| pass | |
| @dvc_exp.command('run') | |
| @click.option('--name', '-n', help='Experiment name') | |
| @click.option('--queue', is_flag=True, help='Queue experiment instead of running') | |
| @click.option('--set-param', '-S', multiple=True, help='Override parameter (key=value)') | |
| def dvc_exp_run(name, queue, set_param): | |
| """Run a DVC experiment. | |
| Example: | |
| hae dvc exp run -n my_exp | |
| hae dvc exp run -S learning_rate=0.001 -S epochs=100 | |
| """ | |
| console.print("[bold magenta]🧪 Running DVC experiment[/bold magenta]\n") | |
| cmd = ["dvc", "exp", "run"] | |
| if name: | |
| cmd.extend(["--name", name]) | |
| if queue: | |
| cmd.append("--queue") | |
| for param in set_param: | |
| cmd.extend(["-S", param]) | |
| show_standalone_alternatives(cmd, "Run DVC experiment") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dvc_exp.command('show') | |
| @click.option('--all-branches', is_flag=True, help='Show experiments from all branches') | |
| @click.option('--all-commits', is_flag=True, help='Show experiments from all commits') | |
| def dvc_exp_show(all_branches, all_commits): | |
| """Show experiments and their metrics. | |
| Example: | |
| hae dvc exp show | |
| hae dvc exp show --all-commits | |
| """ | |
| console.print("[bold cyan]📊 DVC Experiments[/bold cyan]\n") | |
| cmd = ["dvc", "exp", "show"] | |
| if all_branches: | |
| cmd.append("--all-branches") | |
| if all_commits: | |
| cmd.append("--all-commits") | |
| show_standalone_alternatives(cmd, "Show DVC experiments") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dvc_exp.command('diff') | |
| @click.argument('experiment', required=False) | |
| @click.option('--all', is_flag=True, help='Show all metrics and params') | |
| def dvc_exp_diff(experiment, all): | |
| """Show difference between experiments. | |
| Example: | |
| hae dvc exp diff | |
| hae dvc exp diff exp-abc123 | |
| """ | |
| console.print("[bold yellow]🔍 Experiment Diff[/bold yellow]\n") | |
| cmd = ["dvc", "exp", "diff"] | |
| if experiment: | |
| cmd.append(experiment) | |
| if all: | |
| cmd.append("--all") | |
| show_standalone_alternatives(cmd, "Compare DVC experiments") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dvc_exp.command('apply') | |
| @click.argument('experiment', required=True) | |
| def dvc_exp_apply(experiment): | |
| """Apply an experiment to workspace. | |
| Example: | |
| hae dvc exp apply exp-abc123 | |
| """ | |
| console.print(f"[bold green]✨ Applying experiment {experiment}[/bold green]\n") | |
| cmd = ["dvc", "exp", "apply", experiment] | |
| show_standalone_alternatives(cmd, f"Apply experiment {experiment}") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print(f"\n[green]✓ Experiment {experiment} applied[/green]") | |
| sys.exit(result.returncode) | |
| def human_size(bytes_size): | |
| """Convert bytes to human-readable format.""" | |
| if bytes_size is None: | |
| return "?" | |
| for unit in ['B', 'KB', 'MB', 'GB', 'TB']: | |
| if bytes_size < 1024.0: | |
| return f"{bytes_size:.1f}{unit}" | |
| bytes_size /= 1024.0 | |
| return f"{bytes_size:.1f}PB" | |
| def parse_dvc_files(filter_path=None): | |
| """Parse all .dvc files and return metadata.""" | |
| import yaml | |
| dvc_files = [] | |
| for dvc_file in Path('.').rglob('*.dvc'): | |
| # Skip .dvc directory itself | |
| if dvc_file.is_dir(): | |
| continue | |
| if filter_path and filter_path not in str(dvc_file): | |
| continue | |
| try: | |
| with open(dvc_file) as f: | |
| content = yaml.safe_load(f) | |
| if 'outs' in content and len(content['outs']) > 0: | |
| out = content['outs'][0] | |
| data_path = dvc_file.parent / out.get('path', '') | |
| dvc_files.append({ | |
| 'dvc_file': str(dvc_file), | |
| 'data_path': str(data_path), | |
| 'size': out.get('size'), | |
| 'nfiles': out.get('nfiles', 1), | |
| 'exists': data_path.exists() | |
| }) | |
| except Exception as e: | |
| console.print(f"[yellow]Warning: Could not parse {dvc_file}: {e}[/yellow]") | |
| return sorted(dvc_files, key=lambda x: x['dvc_file']) | |
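| # parse_dvc_files() reads the YAML metadata that 'dvc add' writes alongside | |
| # tracked data. A minimal sketch of such a .dvc file (hash elided): | |
| # | |
| # outs: | |
| # - md5: <content-hash>.dir | |
| #   size: 1048576 | |
| #   nfiles: 42 | |
| #   path: mvtec-dataset | |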
| def interactive_dvc_selection(dvc_files): | |
| """Interactive selection of DVC files for pull/push.""" | |
| from InquirerPy import inquirer | |
| # Create choices with status indicators | |
| choices = [] | |
| for item in dvc_files: | |
| status = "✓" if item['exists'] else "✗" | |
| size = human_size(item['size']) | |
| label = f"{status} {item['dvc_file']:60s} ({size})" | |
| choices.append({ | |
| 'name': label, | |
| 'value': item['dvc_file'], | |
| 'enabled': False | |
| }) | |
| # Multi-select checkbox | |
| selected = inquirer.checkbox( | |
| message="Select datasets (space to toggle, enter to confirm):", | |
| choices=choices, | |
| # A bare 'len(x) > 0 or "message"' always returns a truthy value, so the | |
| # check could never fail; InquirerPy takes the message via invalid_message. | |
| validate=lambda x: len(x) > 0, | |
| invalid_message="Please select at least one dataset" | |
| ).execute() | |
| if not selected: | |
| return None, None | |
| # Ask for action | |
| action = inquirer.select( | |
| message="What would you like to do?", | |
| choices=[ | |
| {'name': 'Pull (download from remote)', 'value': 'pull'}, | |
| {'name': 'Push (upload to remote)', 'value': 'push'}, | |
| {'name': 'Cancel', 'value': 'cancel'} | |
| ] | |
| ).execute() | |
| return selected, action | |
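| # Illustrative return values: (['data/mvtec-dataset.dvc'], 'pull') after a | |
| # confirmed selection, or (None, None) when nothing was selected. | |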
| @dvc.command('list') | |
| @click.option('--select', '-s', is_flag=True, help='Interactive selection for pull/push') | |
| @click.option('--filter', help='Filter by path (datasets, fixtures, golden)') | |
| @click.option('--remote-only', is_flag=True, help='Show only remote datasets') | |
| @click.option('--local-only', is_flag=True, help='Show only local datasets') | |
| @click.option('--recursive', '-R', is_flag=True, help='(Compatibility - always recursive)') | |
| @click.option('--tree', '-T', is_flag=True, help='(Compatibility - ignored)') | |
| @click.option('--size', is_flag=True, help='(Compatibility - always shown)') | |
| def dvc_list_files(select, filter, remote_only, local_only, recursive, tree, size): | |
| """List all DVC-tracked datasets with detailed information. | |
| Examples: | |
| hae dvc list # Show all datasets | |
| hae dvc list --select # Interactive pull/push | |
| hae dvc list --filter datasets # Filter by path | |
| hae dvc list --remote-only # Show only remote datasets | |
| """ | |
| console.print("[bold cyan]📁 DVC Tracked Datasets[/bold cyan]\n") | |
| # Parse all .dvc files | |
| dvc_files = parse_dvc_files(filter_path=filter) | |
| # Apply filters | |
| if remote_only: | |
| dvc_files = [f for f in dvc_files if not f['exists']] | |
| if local_only: | |
| dvc_files = [f for f in dvc_files if f['exists']] | |
| if not dvc_files: | |
| console.print("[yellow]No DVC files found matching criteria[/yellow]") | |
| return | |
| # Create Rich table | |
| table = Table(show_header=True, header_style="bold magenta") | |
| table.add_column("Status", style="cyan", width=12) | |
| table.add_column("Size", style="yellow", width=10) | |
| table.add_column("Files", style="green", width=10) | |
| table.add_column("DVC File", style="white") | |
| for item in dvc_files: | |
| status = "[green]✓ LOCAL[/green]" if item['exists'] else "[red]✗ REMOTE[/red]" | |
| size_str = human_size(item['size']) | |
| nfiles = str(item['nfiles']) if item['nfiles'] > 1 else "-" | |
| table.add_row(status, size_str, nfiles, item['dvc_file']) | |
| console.print(table) | |
| console.print(f"\n[bold]Total:[/bold] {len(dvc_files)} DVC-tracked datasets") | |
| # Interactive selection mode | |
| if select: | |
| console.print() | |
| selected_files, action = interactive_dvc_selection(dvc_files) | |
| if action == 'cancel' or not selected_files: | |
| console.print("[yellow]Operation cancelled[/yellow]") | |
| return | |
| # Execute DVC command | |
| cmd = ["dvc", action] + selected_files | |
| console.print(f"\n[cyan]Running:[/cyan] {' '.join(cmd)}\n") | |
| result = subprocess.run(cmd) | |
| if result.returncode == 0: | |
| console.print(f"\n[green]✓ Successfully {action}ed {len(selected_files)} dataset(s)[/green]") | |
| else: | |
| console.print(f"\n[red]✗ {action} failed with code {result.returncode}[/red]") | |
| sys.exit(result.returncode) | |
| else: | |
| # Show pull command hint for remote files | |
| remote_count = sum(1 for f in dvc_files if not f['exists']) | |
| if remote_count > 0: | |
| console.print(f"\n[yellow]💡 {remote_count} datasets are remote. Use --select for interactive pull:[/yellow]") | |
| console.print(" hae dvc list --select") | |
| @dvc.command('dag') | |
| def dvc_dag(): | |
| """Show DVC pipeline DAG. | |
| Example: | |
| hae dvc dag | |
| """ | |
| console.print("[bold cyan]🔀 DVC Pipeline DAG[/bold cyan]\n") | |
| cmd = ["dvc", "dag"] | |
| show_standalone_alternatives(cmd, "Show pipeline DAG") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dvc.command('metrics') | |
| @click.option('--show-md', is_flag=True, help='Show as markdown table') | |
| @click.option('--all-branches', is_flag=True, help='Show metrics from all branches') | |
| def dvc_metrics(show_md, all_branches): | |
| """Show metrics from tracked files. | |
| Example: | |
| hae dvc metrics | |
| hae dvc metrics --show-md | |
| """ | |
| console.print("[bold magenta]📈 DVC Metrics[/bold magenta]\n") | |
| cmd = ["dvc", "metrics", "show"] | |
| if show_md: | |
| cmd.append("--show-md") | |
| if all_branches: | |
| cmd.append("--all-branches") | |
| show_standalone_alternatives(cmd, "Show DVC metrics") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @dvc.command('params') | |
| @click.option('--all-branches', is_flag=True, help='Show params from all branches') | |
| def dvc_params(all_branches): | |
| """Show parameters from tracked files. | |
| Example: | |
| hae dvc params | |
| """ | |
| console.print("[bold yellow]⚙️ DVC Parameters[/bold yellow]\n") | |
| cmd = ["dvc", "params", "show"] | |
| if all_branches: | |
| cmd.append("--all-branches") | |
| show_standalone_alternatives(cmd, "Show DVC parameters") | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| # ============================================================================ | |
| # OOD Detection Commands | |
| # ============================================================================ | |
| @main.command('ood-detection') | |
| @click.option('--output-dir', required=True, type=click.Path(), | |
| help='Location for models/results') | |
| @click.option('--model-name', required=True, | |
| help='Name of model to analyze') | |
| @click.option('--max-t', type=int, default=1000, | |
| help='Maximum T to consider reconstructions from') | |
| @click.option('--min-t', type=int, default=0, | |
| help='Minimum T to consider reconstructions from') | |
| @click.option('--t-skip', type=int, default=1, | |
| help='Only use every n reconstructions') | |
| @click.option('--seed', type=int, default=2, | |
| help='Random seed to use') | |
| def ood_detection(output_dir, model_name, max_t, min_t, t_skip, seed): | |
| """Perform OOD (Out-of-Distribution) detection analysis. | |
| Analyzes reconstruction errors at different timesteps to detect OOD samples. | |
| Example: | |
| hae ood-detection --output-dir ./results --model-name decodiff_mvtec | |
| hae ood-detection --output-dir ./results --model-name model_name --max-t 500 | |
| """ | |
| console.print("[bold cyan]🔍 OOD Detection Analysis[/bold cyan]\n") | |
| cmd = [ | |
| "python", "ood_detection.py", | |
| "--output_dir", str(output_dir), | |
| "--model_name", model_name, | |
| "--max_t", str(max_t), | |
| "--min_t", str(min_t), | |
| "--t_skip", str(t_skip), | |
| "--seed", str(seed) | |
| ] | |
| show_standalone_alternatives(cmd, "Perform OOD detection analysis") | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @main.command('ood-results') | |
| @click.option('--output-dir', required=True, type=click.Path(), | |
| help='Location for models/results') | |
| @click.option('--model-name', required=True, | |
| help='Name of model to analyze') | |
| @click.option('--out-data', default=None, | |
| help='Custom OOD data CSV') | |
| @click.option('--max-t', type=int, default=1000, | |
| help='Maximum T to consider reconstructions from') | |
| @click.option('--min-t', type=int, default=0, | |
| help='Minimum T to consider reconstructions from') | |
| @click.option('--t-skip', type=int, default=1, | |
| help='Only use every n reconstructions') | |
| @click.option('--seed', type=int, default=2, | |
| help='Random seed to use') | |
| def ood_results(output_dir, model_name, out_data, max_t, min_t, t_skip, seed): | |
| """Process and visualize OOD detection results. | |
| Creates Excel reports, ROC curves, and PR curves for OOD analysis. | |
| Example: | |
| hae ood-results --output-dir ./results --model-name decodiff_mvtec | |
| hae ood-results --output-dir ./results --model-name model_name --out-data custom.csv | |
| """ | |
| console.print("[bold magenta]📊 OOD Results Processing[/bold magenta]\n") | |
| cmd = [ | |
| "python", "ood_results.py", | |
| "--output_dir", str(output_dir), | |
| "--model_name", model_name, | |
| "--max_t", str(max_t), | |
| "--min_t", str(min_t), | |
| "--t_skip", str(t_skip), | |
| "--seed", str(seed) | |
| ] | |
| if out_data: | |
| cmd.extend(["--out_data", out_data]) | |
| show_standalone_alternatives(cmd, "Process OOD detection results") | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @main.command('ood-graph') | |
| @click.option('--in-csv', required=True, type=click.Path(exists=True), | |
| help='Path to in-distribution CSV') | |
| @click.option('--out-csv', required=True, type=click.Path(exists=True), | |
| help='Path to out-of-distribution CSV') | |
| @click.option('--save', type=click.Path(), | |
| help='Path to save the figure (optional; displays it interactively if omitted)') | |
| def ood_graph(in_csv, out_csv, save): | |
| """Draw LPIPS / MSE comparison grids for OOD analysis. | |
| Creates a 2x2 grid comparing in-distribution vs. out-of-distribution samples | |
| across LPIPS and MSE metrics. | |
| Example: | |
| hae ood-graph --in-csv results/in.csv --out-csv results/out.csv | |
| hae ood-graph --in-csv results/in.csv --out-csv results/out.csv --save fig.png | |
| """ | |
| console.print("[bold green]📈 OOD Visualization[/bold green]\n") | |
| cmd = [ | |
| "python", "ood_graph.py", | |
| "--in_csv", str(in_csv), | |
| "--out_csv", str(out_csv) | |
| ] | |
| if save: | |
| cmd.extend(["--save", str(save)]) | |
| show_standalone_alternatives(cmd, "Draw LPIPS / MSE comparison grids") | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| # ============================================================================ | |
| # Data Preprocessing Commands | |
| # ============================================================================ | |
| @main.group() | |
| def preprocess(): | |
| """Data preprocessing utilities. | |
| Commands for cropping, splitting, and annotating datasets. | |
| """ | |
| pass | |
| @preprocess.command('crop') | |
| @click.option('--input-dir', required=True, type=click.Path(exists=True), | |
| help='Input directory containing images') | |
| @click.option('--output-dir', required=True, type=click.Path(), | |
| help='Output directory for cropped patches') | |
| @click.option('--patch-size', required=True, type=int, | |
| help='Patch side length (e.g., 128)') | |
| @click.option('--patches-per-image', required=True, type=int, | |
| help='Number of patches to extract per image') | |
| @click.option('--is-grayscale', type=int, default=0, | |
| help='1 for grayscale, 0 for color') | |
| @click.option('--ext', default='png', | |
| help='Output file extension') | |
| @click.option('--seed', type=int, | |
| help='Random seed for reproducibility') | |
| def preprocess_crop(input_dir, output_dir, patch_size, patches_per_image, is_grayscale, ext, seed): | |
| """Extract random crops from images. | |
| Creates smaller patches from larger images, useful for augmentation | |
| and creating training datasets. | |
| Example: | |
| hae preprocess crop --input-dir ./data/raw --output-dir ./data/crops \\ | |
| --patch-size 128 --patches-per-image 2 | |
| hae preprocess crop --input-dir ./images --output-dir ./patches \\ | |
| --patch-size 256 --patches-per-image 5 --seed 42 | |
| """ | |
| console.print("[bold cyan]✂️ Random Crop Extraction[/bold cyan]\n") | |
| cmd = [ | |
| "python", "-m", "dioodmi.data.random_crop", | |
| "--input_dir", str(input_dir), | |
| "--output_dir", str(output_dir), | |
| "--patch_size", str(patch_size), | |
| "--patches_per_image", str(patches_per_image), | |
| "--is_grayscale", str(is_grayscale), | |
| "--ext", ext | |
| ] | |
| if seed is not None: | |
| cmd.extend(["--seed", str(seed)]) | |
| show_standalone_alternatives(cmd, "Extract random crops from images") | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @preprocess.command('split') | |
| @click.argument('folder', type=click.Path(exists=True)) | |
| @click.option('--valid-ratio', type=float, default=0.2, | |
| help='Ratio for validation set (default: 0.2)') | |
| @click.option('--test-ratio', type=float, default=0.01, | |
| help='Ratio for test set (default: 0.01)') | |
| def preprocess_split(folder, valid_ratio, test_ratio): | |
| """Split dataset into train/valid/test CSV files. | |
| Creates CSV files with image paths for train, validation, and test sets. | |
| Example: | |
| hae preprocess split ./data/images | |
| hae preprocess split ./data/images --valid-ratio 0.1 --test-ratio 0.05 | |
| """ | |
| console.print("[bold yellow]📂 Dataset Splitting[/bold yellow]\n") | |
| cmd = [ | |
| "python", "-m", "dioodmi.data.split_data", | |
| str(folder), | |
| "--valid_ratio", str(valid_ratio), | |
| "--test_ratio", str(test_ratio) | |
| ] | |
| show_standalone_alternatives(cmd, "Split dataset into train/valid/test sets") | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @preprocess.command('annotate') | |
| @click.option('--input-dir', '-i', required=True, type=click.Path(exists=True), | |
| help='Directory containing images to annotate') | |
| @click.option('--output-dir', '-o', required=True, type=click.Path(), | |
| help='Directory to save annotation JSON files') | |
| @click.option('--grid-size', '-g', type=int, default=128, | |
| help='Grid size for patches (default: 128)') | |
| @click.option('--extensions', '-e', multiple=True, | |
| default=['.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'], | |
| help='Image file extensions to process') | |
| @click.option('--single-image', '-s', type=click.Path(exists=True), | |
| help='Process a single image instead of directory') | |
| @click.option('--is-defective', is_flag=True, default=False, | |
| help='Mark images as defective') | |
| @click.option('--random-sample', '-r', type=int, | |
| help='Randomly sample N files from directory') | |
| @click.option('--random-seed', type=int, | |
| help='Random seed for reproducible sampling') | |
| def preprocess_annotate(input_dir, output_dir, grid_size, extensions, single_image, | |
| is_defective, random_sample, random_seed): | |
| """Create annotation JSON files for anomaly detection. | |
| Generates JSON annotations with defective patch information for images. | |
| Example: | |
| hae preprocess annotate -i ./images -o ./annotations | |
| hae preprocess annotate -i ./images -o ./annotations --is-defective | |
| hae preprocess annotate -i ./images -o ./annotations -r 100 --random-seed 42 | |
| """ | |
| console.print("[bold magenta]📝 Annotation Creation[/bold magenta]\n") | |
| cmd = ["python", "scripts/create_annotations.py"] | |
| if single_image: | |
| cmd.extend(["--single-image", str(single_image)]) | |
| else: | |
| cmd.extend(["--input-dir", str(input_dir)]) | |
| cmd.extend([ | |
| "--output-dir", str(output_dir), | |
| "--grid-size", str(grid_size) | |
| ]) | |
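| # Forward --extensions only when they differ from the defaults above. The | |
| # tuple comparison is order-sensitive, so a reordered-but-identical set is | |
| # also forwarded explicitly, which should be harmless for the wrapped script. | |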
| if extensions and extensions != ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.tif'): | |
| for ext in extensions: | |
| cmd.extend(["--extensions", ext]) | |
| if is_defective: | |
| cmd.append("--is-defective") | |
| if random_sample is not None: | |
| cmd.extend(["--random-sample", str(random_sample)]) | |
| if random_seed is not None: | |
| cmd.extend(["--random-seed", str(random_seed)]) | |
| show_standalone_alternatives(cmd, "Create annotation JSON files") | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @preprocess.group('synthetic-artifacts') | |
| def preprocess_synthetic_artifacts(): | |
| """Generate synthetic defect datasets (scratches, circular artifacts, etc.).""" | |
| pass | |
| @preprocess_synthetic_artifacts.command('scratch') | |
| @click.option('--input-dir', required=True, type=click.Path(exists=True), | |
| help='Directory containing source images (PNG expected)') | |
| @click.option('--output-dir', required=True, type=click.Path(), | |
| help='Directory where synthetic dataset will be written') | |
| @click.option('--severities-json', required=True, type=click.Path(exists=True), | |
| help='JSON file describing severity presets') | |
| @click.option('--seed', type=int, | |
| help='Optional RNG seed for reproducibility') | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show commands without executing') | |
| def preprocess_synthetic_artifacts_scratch(input_dir, output_dir, severities_json, seed, dry_run): | |
| """Generate synthetic scratch defects from clean source images. | |
| Wraps dioodmi's synthetic scratch generator to produce defect variants, | |
| masks, and metadata according to severity presets. | |
| Example: | |
| hae preprocess synthetic-artifacts scratch --input-dir ./clean --output-dir ./scratch_out \\ | |
| --severities-json severities_circular_carpet.json | |
| """ | |
| console.print("[bold cyan]🪛 Generating Synthetic Scratches[/bold cyan]\n") | |
| cmd = [ | |
| "python", "-m", "dioodmi.data.synthetic_scratch", | |
| "--input_dir", str(input_dir), | |
| "--output_dir", str(output_dir), | |
| "--severities_json", str(severities_json) | |
| ] | |
| if seed is not None: | |
| cmd.extend(["--seed", str(seed)]) | |
| show_standalone_alternatives(cmd, "Generate synthetic scratch dataset") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @preprocess_synthetic_artifacts.command('circular') | |
| @click.option('--input-dir', required=True, type=click.Path(exists=True), | |
| help='Directory containing source images (PNG expected)') | |
| @click.option('--output-dir', required=True, type=click.Path(), | |
| help='Directory where synthetic dataset will be written') | |
| @click.option('--severities-json', required=True, type=click.Path(exists=True), | |
| help='JSON file describing severity presets') | |
| @click.option('--seed', type=int, | |
| help='Optional RNG seed for reproducibility') | |
| @click.option('--mvtec-format/--no-mvtec-format', default=False, | |
| help='Output using MVTec AD folder structure') | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show commands without executing') | |
| def preprocess_synthetic_artifacts_circular(input_dir, output_dir, severities_json, seed, mvtec_format, dry_run): | |
| """Generate synthetic circular artifacts from clean source images. | |
| Wraps dioodmi's circular artifact generator to create severity-based | |
| datasets with optional MVTec-style folder layout. | |
| Example: | |
| hae preprocess synthetic-artifacts circular --input-dir ./clean --output-dir ./circ_out \\ | |
| --severities-json severities_circular_carpet.json --mvtec-format | |
| """ | |
| console.print("[bold cyan]🟢 Generating Synthetic Circular Artifacts[/bold cyan]\n") | |
| cmd = [ | |
| "python", "-m", "dioodmi.data.synthetic_circular_artifact", | |
| "--input_dir", str(input_dir), | |
| "--output_dir", str(output_dir), | |
| "--severities_json", str(severities_json) | |
| ] | |
| if seed is not None: | |
| cmd.extend(["--seed", str(seed)]) | |
| if mvtec_format: | |
| cmd.append("--mvtec_format") | |
| show_standalone_alternatives(cmd, "Generate synthetic circular artifact dataset") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @main.group() | |
| def data(): | |
| """Data manipulation and analysis utilities. | |
| Commands for processing, transforming, and analyzing anomaly detection datasets. | |
| """ | |
| pass | |
| @data.command('gen-masks') | |
| @click.option('--latent-pad', default=None, help='Padding tuple for latent space') | |
| @click.option('--ddpm-checkpoint-epoch', default=None, help='Specific checkpoint epoch') | |
| @click.option('--reconstruct-ids', default=None, help='File listing test image paths') | |
| @click.option('--output-dir', default='./results', type=click.Path(), help='Output directory') | |
| @click.option('--model-name', default='oneimage', help='Model run name') | |
| @click.option('--augmentation', type=int, default=0, help='Use augmentation') | |
| @click.option('--num-workers', type=int, default=0, help='Number of data loading workers') | |
| @click.option('--cache-data', type=int, default=0, help='Cache dataset in memory') | |
| @click.option('--drop-last', type=int, default=0, help='Drop last incomplete batch') | |
| @click.option('--first-n', default=None, help='Process only first N samples') | |
| @click.option('--is-grayscale', type=int, default=0, help='1 for grayscale images') | |
| @click.option('--spatial-dimension', type=int, default=2, help='2D or 3D data') | |
| @click.option('--image-size', type=int, default=None, help='Target image size') | |
| @click.option('--image-roi', default=None, help='Region of interest') | |
| @click.option('--vqvae-checkpoint', default=None, help='VQVAE checkpoint path') | |
| @click.option('--model-type', default='small', help='Model size variant') | |
| @click.option('--scheduler', default='ddpm', help='Diffusion scheduler') | |
| @click.option('--prediction-type', default='epsilon', help='Prediction type') | |
| @click.option('--beta-schedule', default='linear_beta', help='Beta schedule') | |
| @click.option('--beta-start', type=float, default=1e-4, help='Beta start value') | |
| @click.option('--beta-end', type=float, default=2e-2, help='Beta end value') | |
| @click.option('--b-scale', type=float, default=1.0, help='Beta scale factor') | |
| @click.option('--snr-shift', type=float, default=1.0, help='SNR shift value') | |
| @click.option('--simplex-noise', type=int, default=0, help='Use simplex noise') | |
| @click.option('--show-diff', type=int, default=0, help='Show difference images') | |
| @click.option('--batch-size', type=int, default=10, help='Batch size') | |
| @click.option('--num-train-timesteps', type=int, default=1000, help='Number of diffusion steps') | |
| @click.option('--t-destruct', type=int, default=200, help='Destruction timestep') | |
| @click.option('--t-inpaint', type=int, default=50, help='Inpainting timestep') | |
| @click.option('--resample-steps', type=int, default=4, help='Resampling steps') | |
| @click.option('--binary-threshold', type=int, default=50, help='Binary mask threshold') | |
| @click.option('--detect-threshold', type=float, default=10, help='Detection threshold') | |
| @click.option('--anomalymap-threshold', type=int, default=40, help='Anomaly map threshold') | |
| @click.option('--masking-threshold', type=int, default=-1, help='Masking threshold') | |
| @click.option('--mask-percentile', type=float, default=95, help='Mask percentile') | |
| @click.option('--kernel-size', type=int, default=3, help='Morphological kernel size') | |
| @click.option('--brightness-scale', type=int, default=0, help='Brightness scaling') | |
| @click.option('--mask-color', default='black', help='Mask color') | |
| @click.option('--mask-anomaly-map', type=int, default=0, help='Generate masked anomaly maps') | |
| @click.option('--detect-anomaly-map', type=int, default=0, help='Generate detection maps') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def data_gen_masks(**kwargs): | |
| """Generate binary masks from anomaly maps. | |
| Uses trained diffusion models to reconstruct images and generate | |
| binary masks highlighting anomalous regions. | |
| Examples: | |
| hae data gen-masks --model-name bottle_model --output-dir ./masks | |
| hae data gen-masks --reconstruct-ids test.txt --binary-threshold 40 | |
| """ | |
| console.print("[cyan]🎭 Generating Binary Masks[/cyan]\n") | |
| # Build command - delegate to original script | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "gen_mask_images.py" | |
| cmd = ["python", str(script_path)] | |
| # Convert kwargs to command-line arguments | |
| for key, value in kwargs.items(): | |
| if key == 'dry_run': | |
| continue | |
| if value is not None and value != '' and value is not False: | |
| # Convert underscore to dash | |
| arg_name = key.replace('_', '-') | |
| if isinstance(value, bool): | |
| if value: | |
| cmd.append(f"--{arg_name}") | |
| else: | |
| cmd.extend([f"--{arg_name}", str(value)]) | |
| show_standalone_alternatives(cmd, "Generate masks from anomaly maps") | |
| if kwargs.get('dry_run'): | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @data.command('save') | |
| @click.argument('args', nargs=-1, type=click.UNPROCESSED) | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def data_save(args, dry_run): | |
| """Save reconstructed images from diffusion model. | |
| This command wraps save_images.py with all its original arguments. | |
| Run 'hae data save --help' to see options or use --dry-run to preview. | |
| Examples: | |
| hae data save --model-name bottle_model --output-dir ./output | |
| hae data save --dry-run # Show equivalent command | |
| """ | |
| console.print("[cyan]💾 Saving Reconstructed Images[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "save_images.py" | |
| cmd = ["python", str(script_path)] + list(args) | |
| show_standalone_alternatives(cmd, "Save reconstructed images") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @data.command('save-pairs') | |
| @click.argument('args', nargs=-1, type=click.UNPROCESSED) | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def data_save_pairs(args, dry_run): | |
| """Save pairs of original and masked images. | |
| This command wraps save_masked_pairs.py with all its original arguments. | |
| Examples: | |
| hae data save-pairs --src-dir ./images --dst ./pairs | |
| hae data save-pairs --dry-run # Show equivalent command | |
| """ | |
| console.print("[cyan]🖼️ Saving Original/Masked Image Pairs[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "save_masked_pairs.py" | |
| cmd = ["python", str(script_path)] + list(args) | |
| show_standalone_alternatives(cmd, "Save original/masked image pairs") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @data.command('compare-contours') | |
| @click.option('--gt', required=True, type=click.Path(exists=True), help='Ground-truth JSONL file') | |
| @click.option('--pred', required=True, type=click.Path(exists=True), help='Prediction JSONL file') | |
| @click.option('--pixth', type=int, default=4, help='Min pixel intersection threshold') | |
| @click.option('--out-dir', required=True, type=click.Path(), help='Output directory') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def data_compare_contours(gt, pred, pixth, out_dir, dry_run): | |
| """Compare predicted contours against ground truth. | |
| Evaluates contour prediction accuracy by computing pixel-wise | |
| intersection between ground truth and predicted anomaly contours. | |
| Examples: | |
| hae data compare-contours --gt truth.jsonl --pred pred.jsonl --out-dir ./eval | |
| hae data compare-contours --gt truth.jsonl --pred pred.jsonl --pixth 10 --out-dir ./eval | |
| """ | |
| console.print("[cyan]📊 Comparing Contours[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "compare_contours.py" | |
| cmd = [ | |
| "python", str(script_path), | |
| "--gt", gt, | |
| "--pred", pred, | |
| "--pixth", str(pixth), | |
| "--out_dir", out_dir | |
| ] | |
| show_standalone_alternatives(cmd, "Compare predicted vs ground-truth contours") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @data.command('copy-from-csv') | |
| @click.argument('csv_file', type=click.Path(exists=True)) | |
| @click.argument('output_dir', type=click.Path()) | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def data_copy_from_csv(csv_file, output_dir, dry_run): | |
| """Copy images listed in CSV to output directory. | |
| Reads a CSV file containing image paths and copies all listed | |
| images to the specified output directory. | |
| Examples: | |
| hae data copy-from-csv data/test.csv ./test_images | |
| hae data copy-from-csv data/train.csv ./train_images --dry-run | |
| """ | |
| console.print("[cyan]📁 Copying Images from CSV[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "copy-images-in-csv.sh" | |
| cmd = ["bash", str(script_path), csv_file, output_dir] | |
| show_standalone_alternatives(cmd, "Copy images listed in CSV") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @main.group() | |
| def model(): | |
| """Model and checkpoint management utilities. | |
| Commands for inspecting checkpoints, managing model files, and checkpoint utilities. | |
| """ | |
| pass | |
| @model.command('info') | |
| @click.argument('checkpoint_path', type=click.Path(exists=True)) | |
| @click.option('--json', 'output_json', is_flag=True, help='Output in JSON format') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def model_info(checkpoint_path, output_json, dry_run): | |
| """Display checkpoint information (epoch, loss, parameters). | |
| Inspects a checkpoint file and displays its metadata including | |
| epoch number, loss values, parameter count, and file size. | |
| Examples: | |
| hae model info ./checkpoints/best.pt | |
| hae model info ./checkpoints/best.pt --json | |
| """ | |
| console.print("[cyan]📦 Checkpoint Information[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "checkpoint_info.py" | |
| cmd = ["python", str(script_path), checkpoint_path] | |
| if output_json: | |
| cmd.append("--json") | |
| show_standalone_alternatives(cmd, "Inspect checkpoint file") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @model.command('utils') | |
| @click.argument('path', type=click.Path(exists=True)) | |
| @click.option('--list', 'list_mode', is_flag=True, help='List all checkpoints in directory') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def model_utils(path, list_mode, dry_run): | |
| """Checkpoint inspection utilities. | |
| Inspect a single checkpoint file or list all checkpoints in a directory | |
| with their epoch, loss, and size information. | |
| Examples: | |
| hae model utils ./checkpoints/model.pt | |
| hae model utils ./checkpoints --list | |
| """ | |
| console.print("[cyan]🔧 Checkpoint Utilities[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "checkpoint_utils.py" | |
| cmd = ["python", str(script_path), path] | |
| if list_mode: | |
| cmd.append("--list") | |
| show_standalone_alternatives(cmd, "Inspect checkpoint utilities") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @model.command('save-ddpm') | |
| @click.argument('args', nargs=-1, type=click.UNPROCESSED) | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def model_save_ddpm(args, dry_run): | |
| """Save original DDPM model format. | |
| Converts and saves DDPM models in original format. | |
| This command wraps orig_ddpm_save.py with all its original arguments. | |
| Examples: | |
| hae model save-ddpm --model-name bottle --output-dir ./models | |
| hae model save-ddpm --dry-run # Show equivalent command | |
| """ | |
| console.print("[cyan]💾 Saving DDPM Model[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "orig_ddpm_save.py" | |
| cmd = ["python", str(script_path)] + list(args) | |
| show_standalone_alternatives(cmd, "Save DDPM model in original format") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @model.command('position-decoder') | |
| @click.argument('args', nargs=-1, type=click.UNPROCESSED) | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def model_position_decoder(args, dry_run): | |
| """Decode position-encoded patches. | |
| Utility for decoding position information from position-encoded image patches. | |
| Provides functions for extracting crop coordinates and image IDs from patches. | |
| Examples: | |
| hae model position-decoder --help # Show available functions | |
| hae model position-decoder --dry-run | |
| """ | |
| console.print("[cyan]🔍 Position Decoder Utility[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "scripts" / "position_decoder.py" | |
| cmd = ["python", str(script_path)] + list(args) | |
| show_standalone_alternatives(cmd, "Decode position-encoded patches") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print("[bold]Running:[/bold]", " ".join(cmd)) | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @main.group() | |
| def system(): | |
| """System diagnostics and health checks. | |
| Commands for checking system configuration, dependencies, and environment setup. | |
| """ | |
| pass | |
| @system.command('check') | |
| @click.option('--verbose', '-v', is_flag=True, help='Show detailed information') | |
| @click.option('--json', 'output_json', is_flag=True, help='Output in JSON format') | |
| def system_check(verbose, output_json): | |
| """Comprehensive system health check. | |
| Checks Python environment, CUDA/GPU setup, dependencies, disk space, | |
| and DiOodMi project configuration. | |
| Examples: | |
| hae system check | |
| hae system check --verbose | |
| hae system check --json | |
| """ | |
| import sys | |
| import platform | |
| import shutil | |
| import json as json_lib | |
| console.print("[bold cyan]🔍 DiOodMi System Health Check[/bold cyan]\n") | |
| checks = {} | |
| all_pass = True | |
| # Python Environment | |
| console.print("[bold]Python Environment:[/bold]") | |
| py_version = f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}" | |
| py_pass = sys.version_info >= (3, 10) | |
| checks['python'] = {'version': py_version, 'pass': py_pass} | |
| console.print(f" {'✓' if py_pass else '✗'} Python {py_version} {'' if py_pass else '(requires >=3.10)'}") | |
| if verbose: | |
| console.print(f" Executable: {sys.executable}") | |
| console.print(f" Platform: {platform.platform()}") | |
| # PyTorch and CUDA | |
| console.print("\n[bold]PyTorch & CUDA:[/bold]") | |
| try: | |
| import torch | |
| torch_version = torch.__version__ | |
| cuda_available = torch.cuda.is_available() | |
| cuda_version = torch.version.cuda if cuda_available else "N/A" | |
| num_gpus = torch.cuda.device_count() if cuda_available else 0 | |
| checks['pytorch'] = { | |
| 'version': torch_version, | |
| 'cuda_available': cuda_available, | |
| 'cuda_version': cuda_version, | |
| 'num_gpus': num_gpus, | |
| 'pass': True | |
| } | |
| console.print(f" ✓ PyTorch {torch_version}") | |
| console.print(f" {'✓' if cuda_available else '✗'} CUDA available: {cuda_available}") | |
| if cuda_available: | |
| console.print(f" ✓ CUDA version: {cuda_version}") | |
| console.print(f" ✓ GPUs detected: {num_gpus}") | |
| if verbose: | |
| for i in range(num_gpus): | |
| name = torch.cuda.get_device_name(i) | |
| props = torch.cuda.get_device_properties(i) | |
| console.print(f" GPU {i}: {name}") | |
| console.print(f" Compute: {props.major}.{props.minor}, Memory: {props.total_memory / 1024**3:.1f}GB") | |
| except ImportError: | |
| checks['pytorch'] = {'pass': False, 'error': 'PyTorch not installed'} | |
| console.print(" ✗ PyTorch not installed") | |
| all_pass = False | |
| # Key Dependencies | |
| console.print("\n[bold]Key Dependencies:[/bold]") | |
| deps = [ | |
| ('monai', 'MONAI'), | |
| ('diffusers', 'Diffusers'), | |
| ('transformers', 'Transformers'), | |
| ('click', 'Click'), | |
| ('rich', 'Rich'), | |
| ] | |
| checks['dependencies'] = {} | |
| for module, name in deps: | |
| try: | |
| mod = __import__(module) | |
| version = getattr(mod, '__version__', 'unknown') | |
| console.print(f" ✓ {name} {version}") | |
| checks['dependencies'][module] = {'version': version, 'pass': True} | |
| except ImportError: | |
| console.print(f" ✗ {name} not installed") | |
| checks['dependencies'][module] = {'pass': False, 'error': 'not installed'} | |
| all_pass = False | |
| # Disk Space | |
| console.print("\n[bold]Disk Space:[/bold]") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| total, used, free = shutil.disk_usage(project_root) | |
| free_gb = free / (1024**3) | |
| disk_pass = free_gb > 10 # Warn if less than 10GB free | |
| checks['disk'] = { | |
| 'free_gb': round(free_gb, 2), | |
| 'total_gb': round(total / (1024**3), 2), | |
| 'pass': disk_pass | |
| } | |
| console.print(f" {'✓' if disk_pass else '⚠'} Free space: {free_gb:.1f} GB {'' if disk_pass else '(low disk space)'}") | |
| if verbose: | |
| console.print(f" Total: {total / (1024**3):.1f} GB") | |
| console.print(f" Used: {used / (1024**3):.1f} GB") | |
| # Project Structure | |
| console.print("\n[bold]Project Structure:[/bold]") | |
| critical_paths = [ | |
| ('project/', 'Core codebase'), | |
| ('workspaces/hae-py/', 'HAE automation'), | |
| ('tests/', 'Test suite'), | |
| ('.git/', 'Git repository'), | |
| ] | |
| checks['project_structure'] = {} | |
| for path, desc in critical_paths: | |
| full_path = project_root / path | |
| exists = full_path.exists() | |
| console.print(f" {'✓' if exists else '✗'} {desc}: {path}") | |
| checks['project_structure'][path] = {'exists': exists, 'pass': exists} | |
| if not exists: | |
| all_pass = False | |
| # Summary | |
| console.print() | |
| if all_pass: | |
| console.print("[bold green]✓ All checks passed![/bold green]") | |
| else: | |
| console.print("[bold yellow]⚠ Some checks failed - see details above[/bold yellow]") | |
| checks['summary'] = {'all_pass': all_pass} | |
| if output_json: | |
| console.print("\n[bold]JSON Output:[/bold]") | |
| print(json_lib.dumps(checks, indent=2)) | |
| sys.exit(0 if all_pass else 1) | |
| @main.group() | |
| def verify(): | |
| """Regression testing and baseline management. | |
| Commands for managing golden reference baselines, comparing evaluation | |
| outputs, and running regression tests to detect when code changes affect results. | |
| """ | |
| pass | |
| @verify.command('create-baseline') | |
| @click.option('--version', required=True, help='Baseline version (e.g., v1.0, v1.1)') | |
| @click.option('--method', type=click.Choice(['method1', 'method2', 'both']), | |
| default='both', help='Which evaluation method(s) to run') | |
| @click.option('--description', default='', help='Description of this baseline') | |
| @click.option('--dataset', default='mvtec', help='Dataset name') | |
| @click.option('--data-dir', default='./mvtec-dataset/', help='Data directory path') | |
| @click.option('--object-class', default='wood', help='Object class') | |
| @click.option('--anomaly-class', default='hole', help='Anomaly class') | |
| @click.option('--model-size', default='UNet_L', help='Model size') | |
| @click.option('--patch-size', type=int, default=1024, help='Patch size') | |
| @click.option('--pretrained', default='./results/decodiff_mvtec/checkpoints/MVTEC-AD-model.pt', | |
| help='Pretrained model path') | |
| @click.option('--annotation-dir', default='./annotations_wood_test_hole', | |
| help='Annotation directory') | |
| @click.option('--reverse-steps', type=int, default=5, help='Reverse diffusion steps') | |
| @click.option('--batch-size', type=int, default=1, help='Batch size') | |
| @click.option('--vae-type', default='ema', help='VAE type') | |
| @click.option('--image-size', type=int, default=1024, help='Image size') | |
| @click.option('--center-size', type=int, default=1024, help='Center crop size') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def verify_create_baseline(version, method, description, dataset, data_dir, object_class, | |
| anomaly_class, model_size, patch_size, pretrained, annotation_dir, | |
| reverse_steps, batch_size, vae_type, image_size, center_size, | |
| dry_run): | |
| """Generate golden reference baseline for regression testing. | |
| This command runs evaluation methods (decodiff_evaluator and/or | |
| decodiff_evaluate_process), collects their outputs (NPY, JSON, confusion | |
| matrices), creates a versioned baseline directory, and prepares it for DVC tracking. | |
| Examples: | |
| hae verify create-baseline --version v1.0 --method both | |
| hae verify create-baseline --version v1.1 --method method1 | |
| hae verify create-baseline --version v1.1 --object-class bottle --anomaly-class crack | |
| """ | |
| console.print("[cyan]📦 Generating Golden Reference Baseline[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "scripts" / "generate_baseline.py" | |
| cmd = [ | |
| "python", str(script_path), | |
| "--version", version, | |
| "--method", method, | |
| "--description", description, | |
| "--dataset", dataset, | |
| "--data-dir", data_dir, | |
| "--object-class", object_class, | |
| "--anomaly-class", anomaly_class, | |
| "--model-size", model_size, | |
| "--patch-size", str(patch_size), | |
| "--pretrained", pretrained, | |
| "--annotation-dir", annotation_dir, | |
| "--reverse-steps", str(reverse_steps), | |
| "--batch-size", str(batch_size), | |
| "--vae-type", vae_type, | |
| "--image-size", str(image_size), | |
| "--center-size", str(center_size) | |
| ] | |
| show_standalone_alternatives(cmd, "Generate golden reference baseline") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @verify.command('approve-baseline') | |
| @click.option('--version', required=True, help='Baseline version to approve (e.g., v1.1)') | |
| @click.option('--auto-approve', is_flag=True, help='Skip approval confirmation') | |
| @click.option('--skip-dvc', is_flag=True, help='Skip DVC tracking (for testing)') | |
| @click.option('--skip-git', is_flag=True, help='Skip git commit (for testing)') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show command without executing') | |
| def verify_approve_baseline(version, auto_approve, skip_dvc, skip_git, dry_run): | |
| """Approve and finalize a new golden reference baseline. | |
| This command compares the candidate baseline with the current baseline, | |
| shows differences and statistics, prompts for approval, updates the | |
| 'current' symlink, tracks the baseline with DVC, and commits changes to git. | |
| Examples: | |
| hae verify approve-baseline --version v1.1 | |
| hae verify approve-baseline --version v1.1 --auto-approve | |
| hae verify approve-baseline --version v1.1 --skip-dvc --skip-git | |
| """ | |
| console.print("[cyan]✅ Approving Golden Reference Baseline[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| script_path = project_root / "scripts" / "approve_baseline.py" | |
| cmd = ["python", str(script_path), "--version", version] | |
| if auto_approve: | |
| cmd.append("--auto-approve") | |
| if skip_dvc: | |
| cmd.append("--skip-dvc") | |
| if skip_git: | |
| cmd.append("--skip-git") | |
| show_standalone_alternatives(cmd, "Approve and finalize baseline") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @verify.command('compare') | |
| @click.argument('baseline_dir', type=click.Path(exists=True)) | |
| @click.argument('candidate_dir', type=click.Path(exists=True)) | |
| @click.option('--method', type=click.Choice(['method1', 'method2', 'both']), | |
| default='both', help='Which method outputs to compare') | |
| @click.option('--verbose', '-v', is_flag=True, help='Show detailed comparison results') | |
| @click.option('--json', 'output_json', is_flag=True, help='Output results in JSON format') | |
| @click.option('--tolerance', type=float, help='Custom tolerance for numerical comparisons') | |
| def verify_compare(baseline_dir, candidate_dir, method, verbose, output_json, tolerance): | |
| """Compare two baseline directories. | |
| Compares outputs (NPY arrays, JSON metrics, confusion matrices) between | |
| two baseline directories and shows detailed differences. Useful for | |
| reviewing changes before approving a new baseline. | |
| Examples: | |
| hae verify compare tests/golden/v1.0 tests/golden/v1.1 | |
| hae verify compare tests/golden/v1.0 tests/golden/v1.1 --method method1 --verbose | |
| hae verify compare tests/golden/v1.0 tests/golden/v1.1 --json | |
| """ | |
| console.print("[cyan]🔍 Comparing Baselines[/cyan]\n") | |
| baseline_path = Path(baseline_dir) | |
| candidate_path = Path(candidate_dir) | |
| console.print(f"[bold]Baseline:[/bold] {baseline_path}") | |
| console.print(f"[bold]Candidate:[/bold] {candidate_path}\n") | |
| # Import comparison utilities | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| sys.path.insert(0, str(project_root)) | |
| try: | |
| from project.utils.comparison import ( | |
| compare_directories, | |
| summarize_comparison_results, | |
| ToleranceConfig | |
| ) | |
| except ImportError as e: | |
| console.print(f"[red]✗ Failed to import comparison utilities: {e}[/red]") | |
| sys.exit(1) | |
| # Configure tolerance | |
| if tolerance is not None:  # treat an explicit 0.0 as a valid tolerance | |
| tolerance_config = ToleranceConfig( | |
| npy_atol=tolerance, | |
| npy_rtol=tolerance, | |
| metrics_atol=tolerance, | |
| metrics_rtol=tolerance | |
| ) | |
| else: | |
| tolerance_config = ToleranceConfig() | |
| all_pass = True | |
| # Compare Method 1 | |
| if method in ['method1', 'both']: | |
| console.print("[bold]Method 1 Comparison:[/bold]") | |
| console.print("-" * 80) | |
| method1_baseline = baseline_path / "method1" | |
| method1_candidate = candidate_path / "method1" | |
| if method1_baseline.exists() and method1_candidate.exists(): | |
| results = compare_directories( | |
| method1_baseline, | |
| method1_candidate, | |
| tolerance_config | |
| ) | |
| summary = summarize_comparison_results(results, verbose=verbose) | |
| console.print(summary) | |
| # Check if any comparison failed | |
| for result in results: | |
| if not result.passed: | |
| all_pass = False | |
| else: | |
| console.print("[yellow]⚠ Method 1 directories not found in one or both baselines[/yellow]") | |
| all_pass = False | |
| console.print() | |
| # Compare Method 2 | |
| if method in ['method2', 'both']: | |
| console.print("[bold]Method 2 Comparison:[/bold]") | |
| console.print("-" * 80) | |
| method2_baseline = baseline_path / "method2" | |
| method2_candidate = candidate_path / "method2" | |
| if method2_baseline.exists() and method2_candidate.exists(): | |
| results = compare_directories( | |
| method2_baseline, | |
| method2_candidate, | |
| tolerance_config | |
| ) | |
| summary = summarize_comparison_results(results, verbose=verbose) | |
| console.print(summary) | |
| # Check if any comparison failed | |
| for result in results: | |
| if not result.passed: | |
| all_pass = False | |
| else: | |
| console.print("[yellow]⚠ Method 2 directories not found in one or both baselines[/yellow]") | |
| all_pass = False | |
| console.print() | |
| # Summary | |
| if all_pass: | |
| console.print("[bold green]✓ All comparisons passed![/bold green]") | |
| sys.exit(0) | |
| else: | |
| console.print("[bold yellow]⚠ Some comparisons failed - see details above[/bold yellow]") | |
| sys.exit(1) | |
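| # For reference, a minimal sketch of the tolerance semantics used above: the | |
| # atol/rtol pairs in ToleranceConfig are interpreted the way numpy interprets | |
| # them in np.allclose. Illustrative only (the real logic lives in | |
| # project.utils.comparison); `_npy_within_tolerance` is a hypothetical name. | |
| def _npy_within_tolerance(baseline_arr, candidate_arr, atol=1e-8, rtol=1e-5): | |
| """True when |candidate - baseline| <= atol + rtol * |baseline| everywhere.""" | |
| import numpy as np | |
| if baseline_arr.shape != candidate_arr.shape: | |
| return False | |
| return bool(np.allclose(candidate_arr, baseline_arr, rtol=rtol, atol=atol)) | |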
| @verify.command('list') | |
| @click.option('--verbose', '-v', is_flag=True, help='Show detailed information') | |
| @click.option('--json', 'output_json', is_flag=True, help='Output in JSON format') | |
| def verify_list(verbose, output_json): | |
| """List available golden reference baselines. | |
| Shows all versioned baselines in tests/golden/ directory with metadata | |
| information (version, date, description, git commit). | |
| Examples: | |
| hae verify list | |
| hae verify list --verbose | |
| hae verify list --json | |
| """ | |
| console.print("[cyan]📋 Available Golden Reference Baselines[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| golden_dir = project_root / "tests" / "golden" | |
| if not golden_dir.exists(): | |
| console.print("[yellow]⚠ Golden references directory not found[/yellow]") | |
| console.print(f"Expected: {golden_dir}") | |
| return | |
| # Find version directories | |
| version_dirs = [] | |
| for item in golden_dir.iterdir(): | |
| if item.is_dir() and item.name.startswith('v'): | |
| version_dirs.append(item) | |
| if not version_dirs: | |
| console.print("[yellow]⚠ No baselines found[/yellow]") | |
| console.print(f"\nGenerate a baseline with: hae verify create-baseline --version v1.0") | |
| return | |
| # Sort by version number | |
| def parse_version(name): | |
| try: | |
| parts = name[1:].split('.') | |
| return tuple(int(p) for p in parts) | |
| except ValueError: | |
| return (0, 0, 0) | |
| version_dirs.sort(key=lambda p: parse_version(p.name), reverse=True) | |
| # Find current baseline | |
| current_link = golden_dir / "current" | |
| current_version = None | |
| if current_link.exists(): | |
| if current_link.is_symlink(): | |
| current_version = current_link.resolve().name | |
| elif current_link.is_dir(): | |
| current_version = "current" | |
| # Display baselines | |
| import json as json_lib | |
| baselines = [] | |
| for version_dir in version_dirs: | |
| metadata_path = version_dir / "metadata.json" | |
| baseline_info = { | |
| 'version': version_dir.name, | |
| 'current': version_dir.name == current_version, | |
| 'path': str(version_dir) | |
| } | |
| if metadata_path.exists(): | |
| try: | |
| with open(metadata_path, 'r') as f: | |
| metadata = json_lib.load(f) | |
| baseline_info['date'] = metadata.get('date', 'unknown') | |
| baseline_info['description'] = metadata.get('description', '') | |
| baseline_info['git_commit'] = metadata.get('git_commit', 'unknown')[:8] | |
| baseline_info['git_branch'] = metadata.get('git_branch', 'unknown') | |
| except Exception as e: | |
| baseline_info['error'] = str(e) | |
| baselines.append(baseline_info) | |
| # Output | |
| if output_json: | |
| console.print(json_lib.dumps(baselines, indent=2)) | |
| else: | |
| table = Table(title="Golden Reference Baselines") | |
| table.add_column("Version", style="cyan") | |
| table.add_column("Current", style="green") | |
| table.add_column("Date", style="yellow") | |
| table.add_column("Description") | |
| table.add_column("Git Commit", style="magenta") | |
| for baseline in baselines: | |
| current_marker = "✓" if baseline['current'] else "" | |
| date = baseline.get('date', 'unknown') | |
| if date != 'unknown' and 'T' in date: | |
| date = date.split('T')[0] # Show only date part | |
| table.add_row( | |
| baseline['version'], | |
| current_marker, | |
| date, | |
| baseline.get('description', ''), | |
| baseline.get('git_commit', 'unknown') | |
| ) | |
| console.print(table) | |
| if verbose: | |
| console.print("\n[bold]Details:[/bold]") | |
| for baseline in baselines: | |
| console.print(f"\n[cyan]{baseline['version']}[/cyan]") | |
| console.print(f" Path: {baseline['path']}") | |
| if 'git_branch' in baseline: | |
| console.print(f" Branch: {baseline['git_branch']}") | |
| if 'error' in baseline: | |
| console.print(f" [red]Error: {baseline['error']}[/red]") | |
| console.print(f"\n[dim]Use 'hae verify compare' to compare baselines[/dim]") | |
| @verify.command('run') | |
| @click.option('--baseline', required=True, help='Baseline version to test against (e.g., v1.0)') | |
| @click.option('--threshold', type=float, default=1e-5, | |
| help='MAE threshold for pass/fail (default: 1e-5)') | |
| @click.option('--full', is_flag=True, | |
| help='Run full comparison including cross-method validation') | |
| @click.option('--visualize/--no-visualize', default=True, | |
| help='Generate visualizations (default: enabled)') | |
| @click.option('--dry-run/--no-dry-run', default=False, | |
| help='Show command without executing') | |
| @click.option('--use-script', is_flag=True, | |
| help='Use legacy run_regression_test.py script (fallback mode)') | |
| def verify_run(baseline, threshold, full, visualize, dry_run, use_script): | |
| """Run regression test suite comparing against golden baseline. | |
| By default, runs both decodiff_evaluator and anomaly_detector vs baseline (quick mode). | |
| Use --full for comprehensive testing including cross-method validation. | |
| Uses pipeline-based experiment system for declarative, reproducible testing. | |
| Test Flow (Quick Mode - Default): | |
| [1/6] Create test environment | |
| [2/6] Run decodiff_evaluator evaluation (hae eval) | |
| [3/6] Run anomaly_detector evaluation (hae anomaly-detect) | |
| [4/6] Compare decodiff_evaluator vs baseline | |
| [5/6] Compare anomaly_detector vs baseline | |
| [6/6] Generate quick mode summary | |
| Test Flow (Full Mode - --full): | |
| [1/7] Create test environment | |
| [2/7] Run decodiff_evaluator evaluation | |
| [3/7] Run anomaly_detector evaluation | |
| [4/7] Cross-validate decodiff_evaluator vs anomaly_detector | |
| [5/7] Compare decodiff_evaluator vs baseline | |
| [6/7] Compare anomaly_detector vs baseline | |
| [7/7] Generate full mode summary | |
| Examples: | |
| # Quick test (both evaluators vs baseline, ~2-3 minutes) | |
| hae verify run --baseline v1.0 | |
| # Full test (includes cross-validation, ~5-10 minutes) | |
| hae verify run --baseline v1.0 --full | |
| # Custom threshold | |
| hae verify run --baseline v1.0 --threshold 1e-6 | |
| # Dry run to see what will execute | |
| hae verify run --baseline v1.0 --dry-run | |
| # Use legacy script (fallback) | |
| hae verify run --baseline v1.0 --use-script | |
| """ | |
| console.print("[cyan]🧪 Running Regression Test Suite[/cyan]\n") | |
| project_root = Path(__file__).parent.parent.parent.parent.parent | |
| # Determine mode and experiment name | |
| if use_script: | |
| # Legacy mode: use run_regression_test.py script | |
| tests_dir = project_root / "workspaces" / "hae-py" / "tests" | |
| script_path = tests_dir / "run_regression_test.py" | |
| if not script_path.exists(): | |
| console.print(f"[red]✗ Regression test script not found: {script_path}[/red]") | |
| sys.exit(1) | |
| # Build command | |
| cmd = [ | |
| "py", "-3.11", str(script_path), | |
| "--baseline", baseline, | |
| "--threshold", str(threshold) | |
| ] | |
| if full: | |
| cmd.append("--full") | |
| if visualize: | |
| cmd.append("--visualize") | |
| # Show standalone command | |
| show_standalone_alternatives(cmd, "Run regression test suite (legacy script)") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| # Execute regression test | |
| console.print(f"[bold]Mode:[/bold] {'Full Comparison' if full else 'Quick Test'} (legacy script)") | |
| console.print(f"[bold]Baseline:[/bold] {baseline}") | |
| console.print(f"[bold]Threshold:[/bold] {threshold:.0e}") | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| else: | |
| # New mode: use pipeline-based experiments | |
| experiment_name = 'regression-test-full' if full else 'regression-test-quick' | |
| # Build hae experiment run command | |
| cmd = [ | |
| "hae", "experiment", "run", experiment_name, | |
| "--variables", f"baseline_dir=./fixtures/golden/{baseline}", | |
| "--variables", f"threshold={threshold}", | |
| "--variables", f"visualize={'true' if visualize else 'false'}" | |
| ] | |
| # Show standalone command | |
| show_standalone_alternatives(cmd, "Run regression test suite (pipeline-based)") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| console.print(f"\n[bold]Would run experiment:[/bold] {experiment_name}") | |
| console.print(f"[bold]Variables:[/bold]") | |
| console.print(f" baseline_version: {baseline}") | |
| console.print(f" threshold: {threshold}") | |
| console.print(f" visualize: {visualize}") | |
| return | |
| # Execute experiment | |
| console.print(f"[bold]Mode:[/bold] {'Full Comparison' if full else 'Quick Test'} (pipeline-based)") | |
| console.print(f"[bold]Experiment:[/bold] {experiment_name}") | |
| console.print(f"[bold]Baseline:[/bold] {baseline}") | |
| console.print(f"[bold]Threshold:[/bold] {threshold:.0e}") | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
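| # The `--variables key=value` pairs built above follow the convention the | |
| # experiment runner consumes. A minimal parser sketch (assumed, not the actual | |
| # `hae experiment run` implementation; `_parse_variables` is a hypothetical name): | |
| def _parse_variables(pairs): | |
| """Turn ('a=1', 'b=x=y') into {'a': '1', 'b': 'x=y'} (split on the first '=').""" | |
| return dict(p.split('=', 1) for p in pairs) | |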
| @verify.command('compare-fundamentals') | |
| @click.option('--test-dir', required=True, help='Test output directory') | |
| @click.option('--baseline-dir', required=True, help='Baseline reference directory') | |
| @click.option('--output-dir', required=True, help='Comparison results output directory') | |
| @click.option('--threshold', type=float, default=1e-5, help='MAE threshold for pass/fail') | |
| @click.option('--visualize/--no-visualize', default=True, help='Generate visualization plots') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Preview command without executing') | |
| def verify_compare_fundamentals(test_dir, baseline_dir, output_dir, threshold, visualize, dry_run): | |
| """Compare fundamental NPY variables against baseline. | |
| Compares x0, encoded, image_samples, and latent_samples between | |
| test outputs and golden baseline, computing MAE, MSE, RMSE, max_diff, | |
| and Pearson correlation metrics. | |
| Examples: | |
| hae verify compare-fundamentals \\ | |
| --test-dir ./outputs/decodiff_evaluator \\ | |
| --baseline-dir ./fixtures/golden/v1.0/decodiff_evaluator \\ | |
| --output-dir ./comparison/decodiff_evaluator | |
| """ | |
| console.print("[bold cyan]📊 Comparing Fundamental NPY Variables[/bold cyan]\n") | |
| # Build command | |
| cmd = [ | |
| "python", "workspaces/hae-py/tests/compare_fundamentals.py", | |
| "--test_dir", test_dir, | |
| "--baseline_dir", baseline_dir, | |
| "--output_dir", output_dir, | |
| "--threshold", str(threshold), | |
| ] | |
| if visualize: | |
| cmd.append("--visualize") | |
| # Show standalone alternatives | |
| show_standalone_alternatives(cmd, "Compare fundamental NPY variables") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print(f"[bold]Test directory:[/bold] {test_dir}") | |
| console.print(f"[bold]Baseline directory:[/bold] {baseline_dir}") | |
| console.print(f"[bold]Output directory:[/bold] {output_dir}") | |
| console.print(f"[bold]Threshold:[/bold] MAE < {threshold:.2e}") | |
| console.print(f"[bold]Visualize:[/bold] {visualize}") | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
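| # Minimal sketch of the metrics named in the docstring above (MAE, MSE, RMSE, | |
| # max_diff, Pearson correlation) for one pair of NPY arrays. Illustrative only; | |
| # the actual implementation lives in tests/compare_fundamentals.py and | |
| # `_fundamental_metrics` is a hypothetical name. | |
| def _fundamental_metrics(test_arr, baseline_arr): | |
| import numpy as np | |
| diff = test_arr.astype(np.float64) - baseline_arr.astype(np.float64) | |
| mse = float(np.mean(diff ** 2)) | |
| # np.corrcoef returns a 2x2 matrix; [0, 1] is the Pearson r of the two inputs | |
| r = float(np.corrcoef(test_arr.ravel(), baseline_arr.ravel())[0, 1]) | |
| return { | |
| 'mae': float(np.mean(np.abs(diff))), | |
| 'mse': mse, | |
| 'rmse': float(np.sqrt(mse)), | |
| 'max_diff': float(np.max(np.abs(diff))), | |
| 'pearson_r': r, | |
| } | |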
| @verify.command('compare-cross-methods') | |
| @click.option('--decodiff-evaluator-dir', 'decodiff_evaluator_dir', required=True, | |
| help='decodiff_evaluator output directory') | |
| @click.option('--anomaly-detector-dir', 'anomaly_detector_dir', required=True, | |
| help='anomaly_detector output directory') | |
| @click.option('--output-dir', required=True, help='Comparison results output directory') | |
| @click.option('--threshold', type=float, default=1e-5, help='MAE threshold for pass/fail') | |
| @click.option('--visualize/--no-visualize', default=True, help='Generate visualization plots') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Preview command without executing') | |
| def verify_compare_cross_methods(decodiff_evaluator_dir, anomaly_detector_dir, output_dir, | |
| threshold, visualize, dry_run): | |
| """Cross-validate decodiff_evaluator vs anomaly_detector outputs. | |
| Compares fundamental variables and derived anomaly maps between | |
| both evaluation methods to ensure algorithmic equivalence. | |
| Examples: | |
| hae verify compare-cross-methods \\ | |
| --decodiff-evaluator-dir ./outputs/decodiff_evaluator \\ | |
| --anomaly-detector-dir ./outputs/anomaly_detector \\ | |
| --output-dir ./comparison/cross_method | |
| """ | |
| console.print("[bold cyan]🔄 Cross-Validating Evaluation Methods[/bold cyan]\n") | |
| # Build command | |
| cmd = [ | |
| "python", "workspaces/hae-py/tests/compare_cross_methods.py", | |
| "--decodiff_evaluator_dir", decodiff_evaluator_dir, | |
| "--anomaly_detector_dir", anomaly_detector_dir, | |
| "--output_dir", output_dir, | |
| "--threshold", str(threshold), | |
| ] | |
| if visualize: | |
| cmd.append("--visualize") | |
| # Show standalone alternatives | |
| show_standalone_alternatives(cmd, "Cross-validate evaluation methods") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print(f"[bold]decodiff_evaluator directory:[/bold] {decodiff_evaluator_dir}") | |
| console.print(f"[bold]anomaly_detector directory:[/bold] {anomaly_detector_dir}") | |
| console.print(f"[bold]Output directory:[/bold] {output_dir}") | |
| console.print(f"[bold]Threshold:[/bold] MAE < {threshold:.2e}") | |
| console.print(f"[bold]Visualize:[/bold] {visualize}") | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @verify.command('summary') | |
| @click.option('--mode', type=click.Choice(['quick', 'full']), required=True, | |
| help='Summary mode: quick (baseline only) or full (with cross-validation)') | |
| @click.option('--output-dir', required=True, help='Output directory for summary file') | |
| @click.option('--baseline-version', required=True, help='Baseline version tested against') | |
| @click.option('--decodiff-evaluator-results', 'decodiff_evaluator_results', required=True, | |
| help='Path to decodiff_evaluator comparison results JSON') | |
| @click.option('--anomaly-detector-results', 'anomaly_detector_results', required=True, | |
| help='Path to anomaly_detector comparison results JSON') | |
| @click.option('--cross-method-results', 'cross_method_results', default=None, | |
| help='Path to cross-method comparison results JSON (full mode only)') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Preview command without executing') | |
| def verify_summary(mode, output_dir, baseline_version, decodiff_evaluator_results, | |
| anomaly_detector_results, cross_method_results, dry_run): | |
| """Generate regression test summary report. | |
| Creates a human-readable summary of regression test results, | |
| including pass/fail status, timing information, and key metrics. | |
| Examples: | |
| # Quick mode summary | |
| hae verify summary --mode quick \\ | |
| --output-dir ./outputs \\ | |
| --baseline-version v1.0 \\ | |
| --decodiff-evaluator-results ./comparison/decodiff_evaluator/comparison_results.json \\ | |
| --anomaly-detector-results ./comparison/anomaly_detector/comparison_results.json | |
| # Full mode summary (with cross-validation) | |
| hae verify summary --mode full \\ | |
| --output-dir ./outputs \\ | |
| --baseline-version v1.0 \\ | |
| --decodiff-evaluator-results ./comparison/decodiff_evaluator/comparison_results.json \\ | |
| --anomaly-detector-results ./comparison/anomaly_detector/comparison_results.json \\ | |
| --cross-method-results ./comparison/cross_method/cross_method_results.json | |
| """ | |
| console.print("[bold cyan]📋 Generating Regression Test Summary[/bold cyan]\n") | |
| # Build command | |
| cmd = [ | |
| "python", "workspaces/hae-py/tests/generate_regression_summary.py", | |
| "--mode", mode, | |
| "--output-dir", output_dir, | |
| "--baseline-version", baseline_version, | |
| "--decodiff-evaluator-results", decodiff_evaluator_results, | |
| "--anomaly-detector-results", anomaly_detector_results, | |
| ] | |
| if cross_method_results: | |
| cmd.extend(["--cross-method-results", cross_method_results]) | |
| # Show standalone alternatives | |
| show_standalone_alternatives(cmd, "Generate regression test summary") | |
| if dry_run: | |
| console.print("[yellow]Dry run - not executing[/yellow]") | |
| return | |
| console.print(f"[bold]Mode:[/bold] {mode}") | |
| console.print(f"[bold]Output directory:[/bold] {output_dir}") | |
| console.print(f"[bold]Baseline version:[/bold] {baseline_version}") | |
| console.print(f"[bold]decodiff_evaluator results:[/bold] {decodiff_evaluator_results}") | |
| console.print(f"[bold]anomaly_detector results:[/bold] {anomaly_detector_results}") | |
| if cross_method_results: | |
| console.print(f"[bold]Cross-method results:[/bold] {cross_method_results}") | |
| console.print() | |
| result = subprocess.run(cmd) | |
| sys.exit(result.returncode) | |
| @main.command('info') | |
| def info(): | |
| """Display DiOodMi project information and status. | |
| Shows workspace structure, available models, datasets, and configuration. | |
| """ | |
| console.print("[bold]DiOodMi - HAEDOSA Automation Engine[/bold]\n") | |
| # Project info | |
| table = Table(title="Project Information") | |
| table.add_column("Property", style="cyan") | |
| table.add_column("Value", style="green") | |
| table.add_row("Version", "0.3.0") | |
| table.add_row("Architecture", "Export-based (hybrid)") | |
| table.add_row("Core Verbs", "dioodmi-train, dioodmi-eval, dioodmi-reconstruct") | |
| table.add_row("Python Package", "dioodmi") | |
| table.add_row("Standards", "HAEDOSA MS04, MS07, MS12") | |
| console.print(table) | |
| # Available models | |
| console.print("\n[bold]Available Models:[/bold]") | |
| console.print(" • DDPM (Pixel-space diffusion)") | |
| console.print(" • LDM (Latent diffusion with VQVAE)") | |
| console.print(" • DeCo-Diff (Deviation prediction)") | |
| console.print(" • INFD (Image Neural Field Diffusion)") | |
| # Available datasets | |
| console.print("\n[bold]Supported Datasets:[/bold]") | |
| console.print(" • MVTec-AD (5,346 images, 15 categories)") | |
| console.print(" • VisA (10,822 images, 12 categories)") | |
| console.print(" • Custom (CSV-based format)") | |
| # Quick commands | |
| console.print("\n[bold]Quick Commands:[/bold]") | |
| console.print(" [cyan]# Training & Evaluation[/cyan]") | |
| console.print(" hae train --model decodiff --dataset mvtec --category bottle") | |
| console.print(" hae eval --checkpoint models/bottle.pth --dataset mvtec") | |
| console.print() | |
| console.print(" [cyan]# OOD Detection[/cyan]") | |
| console.print(" hae ood-detection --output-dir ./results --model-name decodiff_mvtec") | |
| console.print(" hae ood-results --output-dir ./results --model-name decodiff_mvtec") | |
| console.print(" hae ood-graph --in-csv results/in.csv --out-csv results/out.csv") | |
| console.print() | |
| console.print(" [cyan]# Data Preprocessing[/cyan]") | |
| console.print(" hae preprocess crop --input-dir ./data --output-dir ./crops --patch-size 128 --patches-per-image 2") | |
| console.print(" hae preprocess split ./data/images --valid-ratio 0.1") | |
| console.print(" hae preprocess annotate -i ./images -o ./annotations") | |
| console.print() | |
| console.print(" [cyan]# Data Processing & Analysis[/cyan]") | |
| console.print(" hae data gen-masks --model-name bottle --output-dir ./masks") | |
| console.print(" hae data compare-contours --gt truth.jsonl --pred pred.jsonl --out-dir ./eval") | |
| console.print(" hae data copy-from-csv data/test.csv ./test_images") | |
| console.print() | |
| console.print(" [cyan]# Model & Checkpoint Management[/cyan]") | |
| console.print(" hae model info ./checkpoints/best.pt") | |
| console.print(" hae model utils ./checkpoints --list") | |
| console.print(" hae model save-ddpm --model-name bottle --output-dir ./models") | |
| console.print() | |
| console.print(" [cyan]# Development Tools[/cyan]") | |
| console.print(" hae dev check-cuda") | |
| console.print(" hae dev sort-interactive --pred-jsonl pred.jsonl --out-dir ./sorted") | |
| console.print(" hae dev jupyter") | |
| console.print() | |
| console.print(" [cyan]# System Diagnostics[/cyan]") | |
| console.print(" hae system check") | |
| console.print(" hae system check --verbose") | |
| console.print() | |
| console.print(" [cyan]# Regression Testing & Baselines[/cyan]") | |
| console.print(" hae verify list") | |
| console.print(" hae verify run --baseline v1.0") | |
| console.print(" hae verify create-baseline --version v1.0 --method both") | |
| console.print(" hae verify compare tests/golden/v1.0 tests/golden/v1.1") | |
| console.print(" hae verify approve-baseline --version v1.1") | |
| @main.command() | |
| @click.option('--full', is_flag=True, help='Full rebuild (clears and re-indexes everything)') | |
| @click.option('--verify', is_flag=True, help='Verify database integrity') | |
| def sync(full, verify): | |
| """Sync experiment database with YAML files and execution history. | |
| Convenience shortcut for 'hae experiment sync'. | |
| By default, performs smart incremental sync (only changed files). | |
| Safe to run frequently (e.g., after git pull or running experiments). | |
| The database can always be rebuilt from YAML files (ground truth). | |
| Feel free to delete the .hae/ directory - sync will recreate it. | |
| Examples: | |
| hae sync # Smart sync (fast, incremental) | |
| hae sync --full # Full rebuild (slower) | |
| hae sync --verify # Check database integrity | |
| For more experiment management commands, see 'hae experiment --help'. | |
| """ | |
| # Call the experiment sync function directly | |
| from click import Context | |
| ctx = Context(experiment_sync) | |
| ctx.invoke(experiment_sync, full=full, verify=verify) | |
| @main.command() | |
| @click.option('--migration-only', is_flag=True, help='Check only migration status (quick)') | |
| @click.option('--experiments', is_flag=True, help='Check only experiment system health') | |
| @click.option('--system', is_flag=True, help='Check only system environment (CUDA, datasets)') | |
| @click.option('--fix', is_flag=True, help='Auto-fix issues where possible') | |
| @click.option('--verbose', is_flag=True, help='Show detailed diagnostics') | |
| def doctor(migration_only, experiments, system, fix, verbose): | |
| """Comprehensive health check for DiOodMi project. | |
| Performs a complete diagnostic of the DiOodMi system including: | |
| - Experiment system migration status and validation | |
| - System environment (CUDA availability, datasets) | |
| - Database integrity and synchronization | |
| - File references and dependencies | |
| This is the recommended command to run after pulling new changes. | |
| Examples: | |
| hae doctor # Full health check | |
| hae doctor --migration-only # Quick migration check | |
| hae doctor --experiments # Only experiment system | |
| hae doctor --system # Only system environment | |
| hae doctor --fix # Auto-repair issues | |
| hae doctor --verbose # Detailed diagnostics | |
| """ | |
| import sys | |
| from pathlib import Path | |
| from rich.table import Table | |
| console.print("[bold cyan]🏥 DiOodMi System Health Check[/bold cyan]\n") | |
| # Track overall status | |
| total_errors = 0 | |
| total_warnings = 0 | |
| # Determine which checks to run | |
| run_all = not (migration_only or experiments or system) | |
| run_experiments = run_all or migration_only or experiments | |
| run_system_check = run_all or system | |
| # ======================================================================== | |
| # 1. Experiment System Health Check | |
| # ======================================================================== | |
| if run_experiments: | |
| console.print("[bold]Experiment System:[/bold]") | |
| # Import experiment doctor functionality | |
| from hae.migration import ( | |
| check_migration_status, | |
| find_old_path_references, | |
| validate_yaml_syntax, | |
| validate_dependencies, | |
| validate_file_references, | |
| fix_yaml_paths | |
| ) | |
| from hae.db import ExperimentDB | |
| errors = 0 | |
| warnings = 0 | |
| # 1.1 Migration Status | |
| status = check_migration_status() | |
| if status.old_definitions_exists: | |
| console.print(f" [red]❌ OLD:[/red] experiments/definitions/ ({status.old_definitions_count} files)") | |
| errors += 1 | |
| if status.old_history_exists: | |
| console.print(f" [red]❌ OLD:[/red] experiments/history/ ({status.old_history_count} records)") | |
| errors += 1 | |
| if status.new_definitions_exists: | |
| console.print(f" [green]✅ NEW:[/green] definitions/ ({status.new_definitions_count} files)") | |
| else: | |
| console.print(f" [yellow]⚠️ NEW:[/yellow] definitions/ not found") | |
| warnings += 1 | |
| if status.new_history_exists: | |
| console.print(f" [dim] NEW: experiment-history/ ({status.new_history_count} records)[/dim]") | |
| if status.needs_migration: | |
| console.print(f" [cyan]💡 Action:[/cyan] Run 'hae experiment migrate'\n") | |
| else: | |
| console.print(f" [green]✓ Migration complete[/green]\n") | |
| # If migration-only flag, stop here | |
| if migration_only: | |
| summary = f"Migration {'needed' if status.needs_migration else 'complete'}" | |
| console.print(f"[dim]{summary}[/dim]") | |
| sys.exit(1 if status.needs_migration else 0) | |
| # 1.2 Experiment Definitions Validation | |
| yaml_dir = Path("definitions/experiments") | |
| if yaml_dir.exists(): | |
| # YAML syntax validation | |
| yaml_files = list(yaml_dir.glob("*.yaml")) | |
| valid_count = 0 | |
| invalid_files = [] | |
| for yaml_file in yaml_files: | |
| result = validate_yaml_syntax(yaml_file) | |
| if result.valid: | |
| valid_count += 1 | |
| else: | |
| invalid_files.append((yaml_file.name, result.error)) | |
| if invalid_files: | |
| console.print(f" [red]❌ YAML Syntax ({valid_count}/{len(yaml_files)} valid)[/red]") | |
| for fname, err in invalid_files[:3]:  # show the first few failures | |
| console.print(f" • {fname}: {err}") | |
| errors += len(invalid_files) | |
| else: | |
| console.print(f" [green]✅ YAML Syntax ({len(yaml_files)}/{len(yaml_files)} valid)[/green]") | |
| # Dependency validation | |
| dep_issues = validate_dependencies(yaml_dir) | |
| if dep_issues: | |
| console.print(f" [red]❌ Dependencies ({len(dep_issues)} broken refs)[/red]") | |
| errors += len(dep_issues) | |
| else: | |
| console.print(f" [green]✅ Dependencies (no broken references)[/green]") | |
| # File reference validation | |
| file_issues = validate_file_references(yaml_dir) | |
| if file_issues: | |
| console.print(f" [yellow]⚠️ Missing Files ({len(file_issues)} warnings)[/yellow]") | |
| warnings += len(file_issues) | |
| else: | |
| console.print(f" [green]✅ File References (all files found)[/green]") | |
| # Old path references | |
| path_issues = find_old_path_references(yaml_dir) | |
| if path_issues: | |
| console.print(f" [yellow]⚠️ Old Paths ({len(path_issues)} found)[/yellow]") | |
| warnings += len(path_issues) | |
| if fix: | |
| console.print(f" [cyan]🔧 Fixing path references...[/cyan]") | |
| fixed_count, actions = fix_yaml_paths(yaml_dir, dry_run=False) | |
| console.print(f" [green]✓ Fixed {fixed_count} files[/green]") | |
| warnings -= len(path_issues) # Remove warnings since we fixed them | |
| # 1.3 Database Status | |
| db_path = Path(".hae/experiments.db") | |
| if not db_path.exists(): | |
| console.print(" [yellow]⚠️ .hae/experiments.db not found[/yellow]") | |
| console.print(" [cyan]💡 Action:[/cyan] Run 'hae experiment sync'") | |
| warnings += 1 | |
| else: | |
| try: | |
| db = ExperimentDB() | |
| def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0] | |
| exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0] | |
| if def_count == 0 and status.new_definitions_exists: | |
| console.print(f" [yellow]⚠️ Database empty (0 definitions)[/yellow]") | |
| console.print(" [cyan]💡 Action:[/cyan] Run 'hae experiment sync'") | |
| warnings += 1 | |
| else: | |
| console.print(f" [green]✅ Database OK ({def_count} definitions, {exec_count} executions)[/green]") | |
| db.close() | |
| except Exception as e: | |
| console.print(f" [red]❌ Database error: {e}[/red]") | |
| errors += 1 | |
| # 1.4 Format Validation (experiment.org must be .org, not .md) | |
| history_dir = Path("experiment-history") | |
| if history_dir.exists(): | |
| # Check for .md files in experiment-history/ (should be .org) | |
| md_files = list(history_dir.glob("**/*.md")) | |
| # Filter to experiment record files (experiment.md / notes.md), not other docs | |
| experiment_md_files = [f for f in md_files if f.name in ["experiment.md", "notes.md"]] | |
| if experiment_md_files: | |
| console.print(f" [red]❌ Format Violation ({len(experiment_md_files)} .md files in experiment-history/)[/red]") | |
| for md_file in experiment_md_files[:3]: # Show first 3 | |
| console.print(f" • {md_file}") | |
| if len(experiment_md_files) > 3: | |
| console.print(f" ... and {len(experiment_md_files) - 3} more") | |
| console.print(" [cyan]💡 Action:[/cyan] Experiment records MUST be .org (org-roam :ID: required)") | |
| console.print(" [cyan]💡 See:[/cyan] docs/DOCUMENTATION-STANDARDS.org") | |
| errors += len(experiment_md_files) | |
| else: | |
| console.print(f" [green]✅ Format Compliance (experiment records are .org)[/green]") | |
| console.print() | |
| total_errors += errors | |
| total_warnings += warnings | |
| # ======================================================================== | |
| # 2. System Environment Check | |
| # ======================================================================== | |
| if run_system_check: | |
| console.print("[bold]System Environment:[/bold]") | |
| env_errors = 0 | |
| env_warnings = 0 | |
| # 2.1 CUDA availability | |
| try: | |
| import torch | |
| if torch.cuda.is_available(): | |
| cuda_version = torch.version.cuda | |
| device_count = torch.cuda.device_count() | |
| device_name = torch.cuda.get_device_name(0) if device_count > 0 else "N/A" | |
| console.print(f" [green]✅ CUDA {cuda_version} ({device_count} device(s): {device_name})[/green]") | |
| else: | |
| console.print(f" [yellow]⚠️ CUDA not available (CPU-only mode)[/yellow]") | |
| env_warnings += 1 | |
| except ImportError: | |
| console.print(f" [red]❌ PyTorch not installed[/red]") | |
| env_errors += 1 | |
| # 2.2 Dataset availability | |
| datasets_dir = Path("datasets") | |
| if datasets_dir.exists(): | |
| # Count both raw and processed subdirectories | |
| raw_dir = datasets_dir / "raw" | |
| processed_dir = datasets_dir / "processed" | |
| dataset_count = 0 | |
| if raw_dir.exists(): | |
| dataset_count += sum(1 for d in raw_dir.iterdir() if d.is_dir()) | |
| if processed_dir.exists(): | |
| dataset_count += sum(1 for d in processed_dir.iterdir() if d.is_dir()) | |
| console.print(f" [green]✅ Datasets ({dataset_count} found in datasets/)[/green]") | |
| else: | |
| console.print(f" [yellow]⚠️ datasets/ not found[/yellow]") | |
| env_warnings += 1 | |
| # 2.3 Core dependencies | |
| try: | |
| import dioodmi | |
| console.print(f" [green]✅ DiOodMi core package installed[/green]") | |
| except ImportError: | |
| console.print(f" [red]❌ DiOodMi core package not installed[/red]") | |
| console.print(f" [cyan]💡 Action:[/cyan] Run 'uv pip install -e workspaces/dioodmi-py'") | |
| env_errors += 1 | |
| console.print() | |
| total_errors += env_errors | |
| total_warnings += env_warnings | |
| # ======================================================================== | |
| # 3. Summary | |
| # ======================================================================== | |
| console.print("[bold]Summary:[/bold]") | |
| if total_errors == 0 and total_warnings == 0: | |
| console.print("[green]✅ All healthy - no issues found![/green]") | |
| sys.exit(0) | |
| elif total_errors == 0: | |
| console.print(f"[yellow]⚠️ {total_warnings} warning(s) found[/yellow]") | |
| sys.exit(0) | |
| else: | |
| console.print(f"[red]❌ {total_errors} error(s), {total_warnings} warning(s) found[/red]") | |
| if not fix: | |
| console.print("[dim]Run with --fix to attempt automatic repairs[/dim]") | |
| sys.exit(1) | |
| @main.group() | |
| def vision(): | |
| """Vision processing tools. | |
| Commands for image segmentation (SAM) and visual analysis (Ollama). | |
| """ | |
| pass | |
| @vision.command('segment') | |
| @click.argument('image', type=click.Path(exists=True)) | |
| @click.option('--model', type=click.Choice(['vit_b', 'vit_l', 'vit_h']), | |
| default='vit_b', help='SAM model size (default: vit_b)') | |
| @click.option('--checkpoint-dir', type=click.Path(), | |
| help='Directory containing SAM checkpoints') | |
| @click.option('--threshold', type=float, default=0.5, | |
| help='IoU threshold for mask filtering (default: 0.5)') | |
| @click.option('--points-per-side', type=int, default=32, | |
| help='Points per side for segmentation (default: 32, higher=better quality)') | |
| @click.option('--combine', type=click.Choice(['union', 'largest']), default='union', | |
| help='Combine all masks or only largest (default: union)') | |
| @click.option('--output-dir', type=click.Path(), default='./output', | |
| help='Output directory for masks (default: ./output)') | |
| @click.option('--save-binary/--no-binary', default=True, | |
| help='Save binary mask (default: True)') | |
| @click.option('--save-instance/--no-instance', default=False, | |
| help='Save instance mask (default: False)') | |
| @click.option('--max-coverage', type=float, default=95.0, | |
| help='Exclude objects covering more than this % of image (default: 95)') | |
| @click.option('--save-individual/--no-individual', default=False, | |
| help='Save each object mask as separate file for debugging (default: False)') | |
| @click.option('--save-overlay/--no-overlay', default=False, | |
| help='Save colored masks overlaid on original image (default: False)') | |
| def vision_segment(image, model, checkpoint_dir, threshold, points_per_side, | |
| combine, output_dir, save_binary, save_instance, max_coverage, save_individual, save_overlay): | |
| """Segment objects in an image using SAM (Segment Anything Model). | |
| Examples: | |
| hae vision segment image.bmp | |
| hae vision segment image.bmp --model vit_l --threshold 0.9 | |
| hae vision segment image.bmp --save-instance --output-dir ./masks | |
| """ | |
| from pathlib import Path | |
| from .vision import segment_image | |
| console.print("[bold cyan]🎯 SAM Image Segmentation[/bold cyan]\n") | |
| try: | |
| image_path = Path(image) | |
| output_path = Path(output_dir) | |
| output_path.mkdir(parents=True, exist_ok=True) | |
| console.print(f" Image: {image_path.name}") | |
| console.print(f" Model: {model}") | |
| console.print(f" Threshold: {threshold}") | |
| console.print(f" Points per side: {points_per_side}\n") | |
| with console.status("[bold green]Segmenting image..."): | |
| result = segment_image( | |
| image_path, | |
| model_type=model, | |
| checkpoint_dir=Path(checkpoint_dir) if checkpoint_dir else None, | |
| points_per_side=points_per_side, | |
| threshold=threshold, | |
| output_binary=save_binary, | |
| output_instance=save_instance, | |
| combine=combine, | |
| ) | |
| # Save masks | |
| import cv2 | |
| image_name = image_path.stem | |
| if save_binary and result.get('binary_mask') is not None: | |
| binary_file = output_path / f"{image_name}_binary.png" | |
| cv2.imwrite(str(binary_file), result['binary_mask']) | |
| console.print(f" ✓ Binary mask: {binary_file}") | |
| if save_instance and result.get('instance_mask') is not None: | |
| # Filter out objects that cover too much of the image | |
| total_pixels = result['image_shape'][0] * result['image_shape'][1] | |
| filtered_masks = [ | |
| m for m in result['raw_masks'] | |
| if (m['area'] / total_pixels * 100) <= max_coverage | |
| ] | |
| # Create colored version for visualization | |
| from .vision.segment import create_colored_instance_mask | |
| colored_mask = create_colored_instance_mask(filtered_masks, threshold) | |
| instance_file = output_path / f"{image_name}_instance.png" | |
| cv2.imwrite(str(instance_file), colored_mask) | |
| console.print(f" ✓ Instance mask: {instance_file}") | |
| excluded_count = len(result['raw_masks']) - len(filtered_masks) | |
| if excluded_count > 0: | |
| console.print(f" ℹ Excluded {excluded_count} object(s) covering >{max_coverage}% of image") | |
| # Save individual masks for debugging | |
| if save_individual and result['num_objects'] > 0: | |
| import numpy as np | |
| total_pixels = result['image_shape'][0] * result['image_shape'][1] | |
| sorted_masks = sorted(result['raw_masks'], key=lambda x: x['area'], reverse=True) | |
| console.print(f"\n [bold]Saving individual object masks...[/bold]") | |
| for idx, mask in enumerate(sorted_masks): | |
| seg = mask['segmentation'] | |
| # Create white mask on black background | |
| individual_mask = np.zeros(seg.shape, dtype=np.uint8) | |
| individual_mask[seg] = 255 | |
| coverage = (mask['area'] / total_pixels) * 100 | |
| individual_file = output_path / f"{image_name}_object{idx+1}_{coverage:.1f}pct.png" | |
| cv2.imwrite(str(individual_file), individual_mask) | |
| console.print(f" Object {idx+1}: {individual_file.name}") | |
| # Save overlay visualization | |
| if save_overlay and result['num_objects'] > 0: | |
| import numpy as np | |
| from PIL import Image | |
| # Load original image | |
| original = Image.open(image_path) | |
| original_rgb = np.array(original.convert("RGB")) | |
| # Filter masks by coverage | |
| total_pixels = result['image_shape'][0] * result['image_shape'][1] | |
| filtered_masks = [ | |
| m for m in result['raw_masks'] | |
| if (m['area'] / total_pixels * 100) <= max_coverage | |
| ] | |
| # Create colored mask overlay | |
| from .vision.segment import create_colored_instance_mask | |
| colored_mask = create_colored_instance_mask(filtered_masks, threshold) | |
| # Blend with original image (50% transparency) | |
| alpha = 0.5 | |
| overlay = cv2.addWeighted(original_rgb, 1 - alpha, colored_mask, alpha, 0) | |
| overlay_file = output_path / f"{image_name}_overlay.png" | |
| cv2.imwrite(str(overlay_file), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)) | |
| console.print(f" ✓ Overlay: {overlay_file}") | |
| # Show object sizes for debugging | |
| console.print(f"\n[bold green]✓ Found {result['num_objects']} objects[/bold green]") | |
| if result['num_objects'] > 0: | |
| total_pixels = result['image_shape'][0] * result['image_shape'][1] | |
| console.print("\n[bold]Object details:[/bold]") | |
| sorted_masks = sorted(result['raw_masks'], key=lambda x: x['area'], reverse=True) | |
| for idx, mask in enumerate(sorted_masks[:10]): # Show top 10 | |
| area = mask['area'] | |
| coverage = (area / total_pixels) * 100 | |
| iou = mask.get('predicted_iou', 0) | |
| excluded = " [dim](excluded)[/dim]" if coverage > max_coverage else "" | |
| console.print(f" Object {idx+1}: {area:,} pixels ({coverage:.1f}% of image) - IoU: {iou:.3f}{excluded}") | |
| except ImportError: | |
| console.print("[bold red]✗ SAM not available[/bold red]") | |
| console.print(" Install with: pip install 'hae[vision]'") | |
| sys.exit(1) | |
| except Exception as e: | |
| console.print(f"[bold red]✗ Segmentation failed: {e}[/bold red]") | |
| sys.exit(1) | |
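| # Sketch of the --combine semantics above: 'union' ORs every SAM mask into one | |
| # binary mask, while 'largest' keeps only the biggest object. Illustrative only; | |
| # the real logic lives in .vision.segment and `_combine_masks` is a hypothetical name. | |
| def _combine_masks(raw_masks, mode='union'): | |
| import numpy as np | |
| if not raw_masks: | |
| return None | |
| if mode == 'largest': | |
| raw_masks = [max(raw_masks, key=lambda m: m['area'])] | |
| # each SAM mask carries a boolean HxW 'segmentation' array | |
| combined = np.zeros(raw_masks[0]['segmentation'].shape, dtype=np.uint8) | |
| for m in raw_masks: | |
| combined[m['segmentation']] = 255 | |
| return combined | |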
| @vision.command('batch') | |
| @click.argument('input_dir', type=click.Path(exists=True)) | |
| @click.option('--output-dir', type=click.Path(), default='./output', | |
| help='Output directory for masks') | |
| @click.option('--model', type=click.Choice(['vit_b', 'vit_l', 'vit_h']), | |
| default='vit_b', help='SAM model size') | |
| @click.option('--checkpoint-dir', type=click.Path(), | |
| help='Directory containing SAM checkpoints') | |
| @click.option('--threshold', type=float, default=0.5, | |
| help='IoU threshold for mask filtering') | |
| @click.option('--points-per-side', type=int, default=32, | |
| help='Points per side for segmentation') | |
| @click.option('--combine', type=click.Choice(['union', 'largest']), default='union', | |
| help='Combine all masks or only largest') | |
| @click.option('--save-binary/--no-binary', default=True, | |
| help='Save binary masks') | |
| @click.option('--save-instance/--no-instance', default=False, | |
| help='Save instance masks') | |
| @click.option('--recursive/--no-recursive', default=False, | |
| help='Process subdirectories') | |
| @click.option('--limit', type=int, help='Limit number of images to process') | |
| def vision_batch(input_dir, output_dir, model, checkpoint_dir, threshold, points_per_side, | |
| combine, save_binary, save_instance, recursive, limit): | |
| """Batch process images with SAM segmentation. | |
| Fast GPU-optimized processing of multiple images. | |
| Examples: | |
| hae vision batch ./dataset | |
| hae vision batch ./dataset --limit 100 --points-per-side 16 | |
| hae vision batch ./dataset --save-instance --recursive | |
| """ | |
| from pathlib import Path | |
| from rich.progress import Progress, TextColumn, BarColumn, TaskProgressColumn, TimeRemainingColumn | |
| from .vision import batch_segment_images | |
| console.print("[bold cyan]🎯 SAM Batch Segmentation[/bold cyan]\n") | |
| try: | |
| input_path = Path(input_dir) | |
| output_path = Path(output_dir) | |
| console.print(f" Input: {input_path}") | |
| console.print(f" Output: {output_path}") | |
| console.print(f" Model: {model}") | |
| console.print(f" Threshold: {threshold}") | |
| console.print(f" Points per side: {points_per_side}") | |
| if limit: | |
| console.print(f" Limit: {limit} images") | |
| console.print() | |
| # Render a live progress display while the batch runs | |
| with Progress( | |
| TextColumn("[progress.description]{task.description}"), | |
| BarColumn(), | |
| TaskProgressColumn(), | |
| TimeRemainingColumn(), | |
| console=console | |
| ) as progress: | |
| progress.add_task("Processing images...", total=None)  # indeterminate: total unknown up front | |
| stats = batch_segment_images( | |
| input_path, | |
| output_path, | |
| model_type=model, | |
| checkpoint_dir=Path(checkpoint_dir) if checkpoint_dir else None, | |
| points_per_side=points_per_side, | |
| threshold=threshold, | |
| combine=combine, | |
| save_binary=save_binary, | |
| save_instance=save_instance, | |
| recursive=recursive, | |
| limit=limit, | |
| ) | |
| # Print summary | |
| console.print() | |
| console.print(f"[bold green]✓ Batch processing complete[/bold green]") | |
| console.print(f" Total: {stats['total']}") | |
| console.print(f" Successful: {stats['successful']}") | |
| if stats['failed'] > 0: | |
| console.print(f" Failed: {stats['failed']}") | |
| console.print(f" Total objects: {stats['total_objects']}") | |
| console.print(f" Output: {output_path}") | |
| except ImportError: | |
| console.print(f"[bold red]✗ SAM not available[/bold red]") | |
| console.print(f" Install with: pip install 'hae[vision]'") | |
| sys.exit(1) | |
| except Exception as e: | |
| console.print(f"[bold red]✗ Batch processing failed: {e}[/bold red]") | |
| import traceback | |
| traceback.print_exc() | |
| sys.exit(1) | |
| @vision.command('analyze') | |
| @click.argument('image', type=click.Path(exists=True)) | |
| @click.option('--prompt', default='Describe all objects in this image in detail', | |
| help='Prompt for the vision model') | |
| @click.option('--model', default='llama3.2-vision:11b', | |
| help='Ollama vision model name') | |
| @click.option('--host', default='localhost', | |
| help='Ollama server host') | |
| @click.option('--port', type=int, default=11434, | |
| help='Ollama server port') | |
| def vision_analyze(image, prompt, model, host, port): | |
| """Analyze an image using Ollama vision models. | |
| Examples: | |
| hae vision analyze image.bmp | |
| hae vision analyze image.bmp --prompt "Describe any abnormalities" | |
| hae vision analyze image.bmp --model llava:13b | |
| """ | |
| from pathlib import Path | |
| from .vision import analyze_image_ollama | |
| console.print("[bold cyan]👁️ Ollama Vision Analysis[/bold cyan]\n") | |
| try: | |
| image_path = Path(image) | |
| console.print(f" Image: {image_path.name}") | |
| console.print(f" Model: {model}") | |
| console.print(f" Prompt: {prompt}\n") | |
| with console.status("[bold green]Analyzing image..."): | |
| result = analyze_image_ollama( | |
| image_path, | |
| prompt=prompt, | |
| model=model, | |
| host=host, | |
| port=port, | |
| ) | |
| if result['success']: | |
| console.print(f"[bold green]✓ Analysis complete[/bold green]\n") | |
| console.print(result['response']) | |
| else: | |
| console.print(f"[bold red]✗ Analysis failed[/bold red]") | |
| console.print(f" {result['error']}") | |
| sys.exit(1) | |
| except Exception as e: | |
| console.print(f"[bold red]✗ Analysis failed: {e}[/bold red]") | |
| sys.exit(1) | |
| @vision.command('download-models') | |
| @click.option('--model', type=click.Choice(['vit_b', 'vit_l', 'vit_h', 'all']), | |
| default='vit_b', help='Model to download (default: vit_b)') | |
| @click.option('--checkpoint-dir', type=click.Path(), | |
| help='Directory to save checkpoints (default: fixtures/models/)') | |
| @click.option('--force', is_flag=True, | |
| help='Force re-download even if checkpoint exists') | |
| @click.option('--no-verify', is_flag=True, | |
| help='Skip SHA256 verification (not recommended)') | |
| def vision_download_models(model, checkpoint_dir, force, no_verify): | |
| """Download SAM model checkpoints. | |
| Downloads SAM (Segment Anything Model) checkpoints to fixtures/models/ | |
| for DVC tracking and reproducibility. | |
| Model sizes: | |
| vit_b: 358 MB - Base model (fastest) | |
| vit_l: 1.2 GB - Large model (balanced) | |
| vit_h: 2.4 GB - Huge model (best quality) | |
| Examples: | |
| hae vision download-models | |
| hae vision download-models --model vit_l | |
| hae vision download-models --model all | |
| hae vision download-models --checkpoint-dir ~/.cache/sam | |
| """ | |
| from pathlib import Path | |
| from .vision import ( | |
| download_checkpoint, | |
| download_all_checkpoints, | |
| list_available_models, | |
| get_default_checkpoint_dir | |
| ) | |
| console.print("[bold cyan]📥 SAM Model Download[/bold cyan]\n") | |
| # Get target directory | |
| if checkpoint_dir: | |
| target_dir = Path(checkpoint_dir) | |
| else: | |
| target_dir = get_default_checkpoint_dir() | |
| # Check if in project for DVC instructions | |
| from .vision.download import get_project_root | |
| in_project = get_project_root() is not None | |
| # Show available models | |
| available_models = list_available_models() | |
| console.print("[bold]Available models:[/bold]") | |
| for model_type, info in available_models.items(): | |
| console.print(f" {model_type}: {info['size_mb']} MB") | |
| console.print(f"\n[bold]Download location:[/bold] {target_dir}") | |
| console.print() | |
| try: | |
| if model == 'all': | |
| console.print("[bold]Downloading all models...[/bold]\n") | |
| results = download_all_checkpoints( | |
| checkpoint_dir=target_dir if checkpoint_dir else None, | |
| force=force, | |
| verify=not no_verify | |
| ) | |
| # Show results | |
| console.print() | |
| for model_type, result in results.items(): | |
| if result['success']: | |
| status = result.get('status', 'downloaded') | |
| if status == 'already_exists': | |
| console.print(f"[green]✓[/green] {model_type}: Already exists") | |
| else: | |
| console.print(f"[green]✓[/green] {model_type}: Downloaded successfully") | |
| else: | |
| console.print(f"[red]✗[/red] {model_type}: {result.get('error', 'Failed')}") | |
| else: | |
| console.print(f"[bold]Downloading {model}...[/bold]\n") | |
| result = download_checkpoint( | |
| model_type=model, | |
| checkpoint_dir=target_dir if checkpoint_dir else None, | |
| force=force, | |
| verify=not no_verify | |
| ) | |
| if result['success']: | |
| console.print(f"\n[bold green]✓ Download complete[/bold green]") | |
| console.print(f" Path: {result['path']}") | |
| status = result.get('status', 'downloaded') | |
| if status == 'already_exists': | |
| console.print(f" Status: Already exists and verified") | |
| else: | |
| console.print(f" Status: Downloaded and verified") | |
| else: | |
| console.print(f"[bold red]✗ Download failed[/bold red]") | |
| console.print(f" {result.get('error', 'Unknown error')}") | |
| sys.exit(1) | |
| # Show DVC tracking instructions if in project | |
| if in_project and not checkpoint_dir: | |
| console.print("\n[bold cyan]📦 DVC Tracking[/bold cyan]") | |
| console.print("To track checkpoints with DVC, run:") | |
| console.print(" [bold]dvc add fixtures/models/[/bold]") | |
| console.print(" [bold]git add fixtures/models.dvc .gitignore[/bold]") | |
| console.print(" [bold]git commit -m 'Add SAM model checkpoints'[/bold]") | |
| console.print(" [bold]dvc push[/bold]") | |
| console.print("\nTeam members can then download with:") | |
| console.print(" [bold]dvc pull[/bold]") | |
| except Exception as e: | |
| console.print(f"[bold red]✗ Download failed: {e}[/bold red]") | |
| sys.exit(1) | |
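| # Assumed result contract for download_checkpoint / download_all_checkpoints, | |
| # inferred from the handling above rather than from their definitions: | |
| #   {'success': True, 'path': '<checkpoint path>', 'status': 'downloaded' or 'already_exists'} | |
| #   {'success': False, 'error': '<reason>'} | |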
| # ============================================================================ | |
| # Pipeline and Experiment Commands | |
| # ============================================================================ | |
| @main.group() | |
| def pipeline(): | |
| """Pipeline orchestration commands. | |
| Execute multi-stage workflows defined in YAML files. | |
| """ | |
| pass | |
| @pipeline.command('run') | |
| @click.argument('pipeline_id_or_path') | |
| @click.option('--vars', multiple=True, | |
| help='Variable overrides in key=value format (can be repeated)') | |
| @click.option('--dry-run', is_flag=True, | |
| help='Show commands without executing them') | |
| def pipeline_run(pipeline_id_or_path, vars, dry_run): | |
| """Run a pipeline by ID or YAML file path. | |
| Accepts either: | |
| - Pipeline ID (looks in definitions/pipelines/) | |
| - Full path to pipeline YAML file | |
| Execute multi-stage workflows with variable substitution and | |
| sequential stage execution. | |
| Examples: | |
| # By pipeline ID | |
| hae pipeline run smoke-test | |
| hae pipeline run eval-ad --vars model_path=./checkpoints/best.pt | |
| # By full path | |
| hae pipeline run definitions/pipelines/smoke-test.yaml | |
| hae pipeline run pipelines/smoke-test.yaml --vars model_size=UNet_S | |
| hae pipeline run pipelines/smoke-test.yaml --vars model_size=UNet_M --vars epochs=5 | |
| hae pipeline run pipelines/smoke-test.yaml --dry-run | |
| """ | |
| from pathlib import Path | |
| from .commands.pipeline import run_pipeline, parse_vars_overrides | |
| # Check if it's a path that exists | |
| input_path = Path(pipeline_id_or_path) | |
| if input_path.exists() and input_path.is_file(): | |
| yaml_path = input_path | |
| else: | |
| # Treat as pipeline ID - look in definitions directory | |
| definitions_dir = Path("definitions/pipelines") | |
| yaml_path = definitions_dir / f"{pipeline_id_or_path}.yaml" | |
| if not yaml_path.exists(): | |
| console.print(f"[red]❌ Pipeline not found: {pipeline_id_or_path}[/red]") | |
| console.print(f"\n[yellow]Tried:[/yellow]") | |
| console.print(f" - {yaml_path}") | |
| console.print(f"\n[cyan]Available pipelines:[/cyan]") | |
| # Show available pipelines | |
| from rich.table import Table | |
| table = Table(show_header=False, box=None) | |
| if definitions_dir.exists(): | |
| for pipeline_file in sorted(definitions_dir.glob("*.yaml")): | |
| pipeline_id = pipeline_file.stem | |
| table.add_row(f" • {pipeline_id}") | |
| console.print(table) | |
| console.print(f"\n[dim]Usage: hae pipeline run <id> or hae pipeline run <path>[/dim]") | |
| sys.exit(1) | |
| vars_dict = parse_vars_overrides(list(vars)) if vars else None | |
| success = run_pipeline(yaml_path, vars_dict, dry_run=dry_run) | |
| sys.exit(0 if success else 1) | |
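| # Sketch of the assumed `parse_vars_overrides` contract (defined in | |
| # .commands.pipeline; this describes the expected behavior, not the implementation): | |
| #   parse_vars_overrides(["model_size=UNet_S", "epochs=5"]) | |
| #   -> {"model_size": "UNet_S", "epochs": "5"}   # each entry split on the first '=' | |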
| @pipeline.command('list') | |
| @click.option('--definitions', is_flag=True, help='List only pipelines from definitions/') | |
| def pipeline_list(definitions: bool): | |
| """List all available pipelines. | |
| By default, lists all pipelines from definitions/pipelines/*.yaml. | |
| Examples: | |
| hae pipeline list # List all pipelines | |
| hae pipeline list --definitions # List only definitions directory pipelines | |
| """ | |
| from pathlib import Path | |
| import yaml | |
| from rich.table import Table | |
| all_pipelines = [] | |
| definitions_dir = Path("definitions/pipelines") | |
| # List pipelines from definitions directory (default behavior) | |
| if definitions_dir.exists(): | |
| for yaml_file in sorted(definitions_dir.glob("*.yaml")): | |
| try: | |
| with open(yaml_file, encoding='utf-8') as f: | |
| config = yaml.safe_load(f) | |
| pipeline_id = yaml_file.stem | |
| name = config.get('name', pipeline_id) | |
| description = config.get('description', 'No description') | |
| # Get relative path | |
| try: | |
| rel_path = str(yaml_file.relative_to(Path.cwd())) | |
| except ValueError: | |
| rel_path = str(yaml_file) | |
| # Count stages | |
| stages = config.get('stages', []) | |
| num_stages = len(stages) if isinstance(stages, list) else 0 | |
| all_pipelines.append({ | |
| 'id': pipeline_id, | |
| 'name': name, | |
| 'description': description, | |
| 'path': rel_path, | |
| 'stages': num_stages | |
| }) | |
| except Exception as e: | |
| console.print(f"[yellow]Warning: Could not load {yaml_file}: {e}[/yellow]") | |
| if not all_pipelines: | |
| console.print("[yellow]No pipelines found[/yellow]") | |
| console.print(f"[dim]Looking in: {definitions_dir if definitions_dir.exists() else 'definitions/pipelines'}[/dim]") | |
| return | |
| # Display table | |
| table = Table(title="Available Pipelines", show_header=True, header_style="bold", box=None, padding=(0, 1)) | |
| table.add_column("ID", style="cyan", no_wrap=True, width=30) | |
| table.add_column("Name", style="green", no_wrap=True, width=40) | |
| table.add_column("Stages", style="yellow", width=8, justify="right") | |
| table.add_column("Description", style="dim", width=50) | |
| for pipeline in sorted(all_pipelines, key=lambda x: x['id']): | |
| # Truncate name and description | |
| name = pipeline['name'].replace('\n', ' ').replace('\r', ' ').strip() | |
| if len(name) > 38: | |
| name = name[:35] + "..." | |
| desc = pipeline['description'].replace('\n', ' ').replace('\r', ' ').strip() | |
| if len(desc) > 48: | |
| desc = desc[:45] + "..." | |
| pipeline_id = pipeline['id'] | |
| if len(pipeline_id) > 28: | |
| pipeline_id = pipeline_id[:25] + "..." | |
| table.add_row( | |
| pipeline_id, | |
| name, | |
| str(pipeline['stages']), | |
| desc | |
| ) | |
| console.print() | |
| console.print(table) | |
| console.print() | |
| @main.group() | |
| def experiment(): | |
| """Experiment management commands. | |
| Run parameter sweeps and manage experimental workflows. | |
| """ | |
| pass | |
| @experiment.command('list') | |
| @click.option('--tag', multiple=True, help='Filter by tag (can specify multiple)') | |
| @click.option('--exclude-tag', multiple=True, help='Exclude experiments with tag') | |
| @click.option('--search', help='Search in ID and description') | |
| @click.option('--numbered/--no-numbered', default=True, help='Show numbered list for selection') | |
| @click.option('--rebuild', is_flag=True, help='Rebuild database before listing') | |
| @click.option('--no-sync', is_flag=True, help='Skip auto-sync (use cached database)') | |
| def experiment_list(tag, exclude_tag, search, numbered, rebuild, no_sync): | |
| """List all available experiments from database. | |
| Uses the experiment database for fast filtering and search. | |
| Database is auto-synced incrementally unless --no-sync is specified. | |
| Examples: | |
| hae experiment list # Auto-syncs, then lists all | |
| hae experiment list --tag synthetic # Only synthetic experiments | |
| hae experiment list --tag training --tag carpet # Multiple tags (AND) | |
| hae experiment list --exclude-tag test # Exclude test experiments | |
| hae experiment list --search mvtec # Search in ID/description | |
| hae experiment list --rebuild # Full rebuild (slow) | |
| hae experiment list --no-sync # Skip sync (use cached data) | |
| """ | |
| from hae.db import ExperimentDB | |
| from hae.migration import sync_database_incremental | |
| import re | |
| import yaml  # needed for yaml.safe_load in the dependency parsing below | |
| from pathlib import Path | |
| # Initialize database | |
| db = ExperimentDB() | |
| # Sync strategy: rebuild, auto-sync, or skip | |
| if rebuild: | |
| # Full rebuild - clear and re-index everything | |
| console.print("[dim]Rebuilding experiment database...[/dim]") | |
| db.index_definitions() | |
| def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0] | |
| console.print(f"[dim]Indexed {def_count} experiments[/dim]\n") | |
| elif not no_sync: | |
| # Smart auto-sync: only process changed/new files (fast!) | |
| new, updated, unchanged = sync_database_incremental(db) | |
| if new + updated > 0: | |
| console.print(f"[dim]↻ Synced: {new} new, {updated} updated[/dim]") | |
| # else: no_sync is True, skip sync entirely | |
| # Query experiments with filters | |
| experiments = db.query_definitions( | |
| tags=list(tag) if tag else None, | |
| exclude_tags=list(exclude_tag) if exclude_tag else None, | |
| search=search | |
| ) | |
| db.close() | |
| if not experiments: | |
| console.print("[yellow]No experiments found[/yellow]") | |
| if tag: | |
| console.print(f"[dim]Try removing tag filters: {', '.join(tag)}[/dim]") | |
| if search: | |
| console.print(f"[dim]Try removing search filter: {search}[/dim]") | |
| return | |
| # Build ID → number mapping for dependency resolution | |
| id_to_number = {exp.id: idx + 1 for idx, exp in enumerate(experiments)} | |
| # Parse dependencies from YAML files | |
| dependencies = {} | |
| for exp in experiments: | |
| yaml_path = Path(exp.yaml_path) | |
| try: | |
| with open(yaml_path, encoding='utf-8') as f: | |
| config = yaml.safe_load(f) | |
| deps = [] | |
| # Extract depends_on from all experiment definitions | |
| for exp_def in config.get('experiments', []): | |
| if 'depends_on' in exp_def: | |
| dep_list = exp_def['depends_on'] | |
| if isinstance(dep_list, str): | |
| dep_list = [dep_list] | |
| deps.extend(dep_list) | |
| # Deduplicate and resolve to numbers (only for experiments in current view) | |
| dep_ids = list(set(deps)) | |
| dep_numbers = [] | |
| for dep_id in dep_ids: | |
| if dep_id in id_to_number: | |
| dep_numbers.append(id_to_number[dep_id]) | |
| dependencies[exp.id] = sorted(dep_numbers) # Sort for consistency | |
| except Exception: | |
| dependencies[exp.id] = [] | |
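| # Example (hypothetical IDs): if "eval-ad" declares depends_on: ["train-ad"] and | |
| # "train-ad" is row #2 in this listing, the Deps column for "eval-ad" shows "#2". | |
| # Dependencies on experiments filtered out of the current view are omitted. | |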
| # Display table (compact format) | |
| table = Table(title="Available Experiments", show_header=True, header_style="bold", box=None, padding=(0, 1)) | |
| if numbered: | |
| table.add_column("#", style="dim", width=3, no_wrap=True) | |
| table.add_column("ID", style="cyan", no_wrap=True, width=28) | |
| table.add_column("Tags", style="yellow", width=25) | |
| table.add_column("Last Run", style="dim", width=16, no_wrap=True) | |
| table.add_column("Deps", style="magenta", width=12, no_wrap=True) | |
| for idx, exp in enumerate(experiments, start=1): | |
| # Format last run timestamp | |
| if exp.last_run_timestamp: | |
| # Parse ISO timestamp: 2025-11-15T23:21:12Z | |
| match = re.search(r'(\d{4}-\d{2}-\d{2})T(\d{2}:\d{2}):\d{2}', exp.last_run_timestamp) | |
| if match: | |
| last_run_str = f"{match.group(1)} {match.group(2)}" | |
| else: | |
| last_run_str = exp.last_run_timestamp[:16] | |
| else: | |
| last_run_str = "-" | |
| # Format tags (compact) | |
| if exp.tags: | |
| # Show first 3 tags, or indicate more | |
| tags_display = ', '.join(exp.tags[:3]) | |
| if len(exp.tags) > 3: | |
| tags_display += f" +{len(exp.tags) - 3}" | |
| # Truncate to fit column | |
| if len(tags_display) > 23: | |
| tags_display = tags_display[:20] + "..." | |
| else: | |
| tags_display = "-" | |
| # Show the name alongside the ID only when it differs from the ID | |
| if exp.name == exp.id: | |
| id_display = exp.id | |
| else: | |
| id_display = f"{exp.id} ({exp.name})" | |
| # Truncate ID to fit column | |
| if len(id_display) > 26: | |
| id_display = id_display[:23] + "..." | |
| # Format dependencies (show numbers only) | |
| dep_numbers = dependencies.get(exp.id, []) | |
| if dep_numbers: | |
| # Limit to first 3 dependencies | |
| deps_display = ",".join(f"#{n}" for n in dep_numbers[:3]) | |
| if len(dep_numbers) > 3: | |
| deps_display += ",..." | |
| else: | |
| deps_display = "-" | |
| row_data = [id_display, tags_display, last_run_str, deps_display] | |
| if numbered: | |
| row_data.insert(0, str(idx)) | |
| table.add_row(*row_data) | |
| console.print(table) | |
| # Show summary with filter info | |
| summary_parts = [f"Total: {len(experiments)} experiments"] | |
| if tag: | |
| summary_parts.append(f"tags: {', '.join(tag)}") | |
| if exclude_tag: | |
| summary_parts.append(f"excluded: {', '.join(exclude_tag)}") | |
| if search: | |
| summary_parts.append(f"search: {search}") | |
| console.print(f"\n[dim]{' | '.join(summary_parts)}[/dim]") | |
| if numbered: | |
| console.print(f"[dim]Run: hae experiment run <#> or hae experiment run <id>[/dim]") | |
| else: | |
| console.print(f"[dim]Run: hae experiment run <id>[/dim]") | |
| @experiment.command('run') | |
| @click.argument('experiment_id_or_path') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show execution plan without running') | |
| @click.option("-v", "--variables", multiple=True, help="Variables as key=value") | |
| def experiment_run(experiment_id_or_path, dry_run, variables): | |
| """Run an experiment by ID, number, fuzzy match, or YAML file path. | |
| Accepts: | |
| - Number from `hae experiment list`: hae experiment run 3 | |
| - Experiment ID: hae experiment run synthetic-carpet-train | |
| - Fuzzy match (substring): hae experiment run carpet | |
| - Full path: hae experiment run definitions/experiments/synthetic-carpet-train.yaml | |
| Supports three formats: | |
| 1. Orchestration format (definitions/experiments/*.yaml) - with `experiments:` key | |
| 2. Sweep format (definitions/experiments/*.yaml) - with `sweep:` and `cmd:` keys | |
| 3. Sweep-of-orchestrations format - with both `sweep:` and `experiments:` keys | |
| Examples: | |
| hae experiment run 3 # Run by number (from list) | |
| hae experiment run carpet # Fuzzy match | |
| hae experiment run synthetic-carpet-train # By exact ID | |
| hae experiment run definitions/experiments/synthetic-carpet-train.yaml # By path | |
| """ | |
| from pathlib import Path | |
| import yaml | |
| from .commands.experiment import run_experiment_sweep, run_experiment_orchestration | |
| from hae.db import ExperimentDB | |
| # Check if it's a path that exists | |
| input_path = Path(experiment_id_or_path) | |
| if input_path.exists() and input_path.is_file(): | |
| yaml_path = input_path | |
| else: | |
| # Try to resolve using database | |
| db = ExperimentDB() | |
| # Auto-rebuild if database is empty | |
| def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0] | |
| if def_count == 0: | |
| console.print("[dim]Rebuilding experiment database...[/dim]") | |
| db.index_definitions() | |
| # Try numbered selection (e.g., "3") | |
| if experiment_id_or_path.isdigit(): | |
| num = int(experiment_id_or_path) | |
| all_experiments = db.query_definitions() | |
| if 1 <= num <= len(all_experiments): | |
| exp = all_experiments[num - 1] | |
| console.print(f"[dim]Selected #{num}: {exp.id}[/dim]") | |
| yaml_path = Path(exp.yaml_path) | |
| else: | |
| console.print(f"[red]❌ Number out of range: {num} (valid: 1-{len(all_experiments)})[/red]") | |
| console.print(f"[dim]Run 'hae experiment list' to see numbered experiments[/dim]") | |
| db.close() | |
| sys.exit(1) | |
| else: | |
| # Try exact ID match first | |
| definitions_dir = Path("definitions/experiments") | |
| yaml_path = definitions_dir / f"{experiment_id_or_path}.yaml" | |
| if not yaml_path.exists(): | |
| # Try fuzzy matching | |
| matches = db.query_definitions(search=experiment_id_or_path) | |
| if len(matches) == 0: | |
| console.print(f"[red]❌ No experiments found matching: {experiment_id_or_path}[/red]") | |
| console.print(f"\n[yellow]Tried:[/yellow]") | |
| console.print(f" - Exact ID: {experiment_id_or_path}") | |
| console.print(f" - Fuzzy search: {experiment_id_or_path}") | |
| console.print(f"\n[dim]Run 'hae experiment list' to see available experiments[/dim]") | |
| db.close() | |
| sys.exit(1) | |
| elif len(matches) == 1: | |
| exp = matches[0] | |
| console.print(f"[dim]Matched: {exp.id}[/dim]") | |
| yaml_path = Path(exp.yaml_path) | |
| else: | |
| # Multiple matches - show user | |
| console.print(f"[yellow]Multiple experiments match '{experiment_id_or_path}':[/yellow]\n") | |
| from rich.table import Table | |
| table = Table(show_header=True, box=None) | |
| table.add_column("#", style="dim", width=3) | |
| table.add_column("ID", style="cyan") | |
| table.add_column("Tags", style="yellow") | |
| for idx, exp in enumerate(matches, 1): | |
| tags_str = ', '.join(exp.tags[:2]) if exp.tags else "-" | |
| if exp.tags and len(exp.tags) > 2: | |
| tags_str += f" +{len(exp.tags) - 2}" | |
| table.add_row(str(idx), exp.id, tags_str) | |
| console.print(table) | |
| console.print(f"\n[dim]Please specify a more specific ID or use the number[/dim]") | |
| db.close() | |
| sys.exit(1) | |
| db.close() | |
| # Detect format: sweep+experiments (Format 3) > experiments (Format 1) > sweep+cmd (Format 2) | |
| try: | |
| with open(yaml_path, 'r', encoding='utf-8') as f: | |
| config = yaml.safe_load(f) | |
| except Exception as e: | |
| console.print(f"[red]Error loading YAML: {str(e)}[/red]") | |
| sys.exit(1) | |
| from .commands.pipeline import parse_vars_overrides | |
| var_overrides = {} | |
| if variables: | |
| var_overrides = parse_vars_overrides(list(variables)) | |
| if 'sweep' in config and 'experiments' in config: | |
| # Format 3: Sweep-of-Orchestrations (sweep at top level, runs entire workflow per trial) | |
| console.print("[cyan]Detected sweep-of-orchestrations format (sweep + experiments)[/cyan]") | |
| from .commands.experiment import run_experiment_sweep_orchestration | |
| success = run_experiment_sweep_orchestration(yaml_path, dry_run=dry_run, var_overrides=var_overrides) | |
| elif 'experiments' in config: | |
| # Format 1: Orchestration (multi-experiment workflow) | |
| console.print("[cyan]Detected orchestration experiment format[/cyan]") | |
| success = run_experiment_orchestration(yaml_path, dry_run=dry_run, var_overrides=var_overrides) | |
| elif 'sweep' in config and 'cmd' in config: | |
| # Format 2: Sweep (parameter sweep with single command) | |
| console.print("[cyan]Detected sweep experiment format[/cyan]") | |
| success = run_experiment_sweep(yaml_path, dry_run=dry_run, var_overrides=var_overrides) | |
| else: | |
| console.print("[red]Error: Unrecognized experiment format[/red]") | |
| console.print("[yellow]Expected one of:[/yellow]") | |
| console.print(" 1. 'experiments' key (orchestration format)") | |
| console.print(" 2. 'sweep' + 'cmd' keys (sweep format)") | |
| console.print(" 3. 'sweep' + 'experiments' keys (sweep-of-orchestrations format)") | |
| sys.exit(1) | |
| sys.exit(0 if success else 1) | |
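| # Minimal illustrative YAML shapes for the three formats (field values are | |
| # hypothetical; only the discriminating top-level keys matter for detection): | |
| #   Format 1 (orchestration):           experiments: [...] | |
| #   Format 2 (sweep):                   sweep: {lr: [0.1, 0.01]}  plus  cmd: "hae train --lr {lr}" | |
| #   Format 3 (sweep-of-orchestrations): sweep: {...}  plus  experiments: [...] | |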
| def _run_experiment_from_yaml(yaml_path: Path, dry_run: bool = False) -> bool: | |
| """Internal function to run an experiment from a YAML file. | |
| Args: | |
| yaml_path: Path to experiment YAML file | |
| dry_run: If True, show what would be executed without running | |
| Returns: | |
| True if experiment succeeded, False otherwise | |
| """ | |
| import yaml | |
| from .commands.experiment import run_experiment_sweep, run_experiment_orchestration, run_experiment_sweep_orchestration | |
| # Load config to detect format | |
| try: | |
| with open(yaml_path, 'r', encoding='utf-8') as f: | |
| config = yaml.safe_load(f) | |
| except Exception as e: | |
| console.print(f"[red]Error loading YAML: {str(e)}[/red]") | |
| return False | |
| # Detect format: sweep+experiments (Format 3) > experiments (Format 1) > sweep+cmd (Format 2) | |
| if 'sweep' in config and 'experiments' in config: | |
| # Format 3: Sweep-of-Orchestrations | |
| console.print("[cyan]Detected sweep-of-orchestrations format (sweep + experiments)[/cyan]") | |
| return run_experiment_sweep_orchestration(yaml_path, dry_run=dry_run) | |
| elif 'experiments' in config: | |
| # Format 1: Orchestration | |
| console.print("[cyan]Detected orchestration experiment format[/cyan]") | |
| return run_experiment_orchestration(yaml_path, dry_run=dry_run) | |
| elif 'sweep' in config and 'cmd' in config: | |
| # Format 2: Sweep | |
| console.print("[cyan]Detected sweep experiment format[/cyan]") | |
| return run_experiment_sweep(yaml_path, dry_run=dry_run) | |
| else: | |
| console.print("[red]Error: Unrecognized experiment format[/red]") | |
| console.print("[yellow]Expected one of:[/yellow]") | |
| console.print(" 1. 'experiments' key (orchestration format)") | |
| console.print(" 2. 'sweep' + 'cmd' keys (sweep format)") | |
| console.print(" 3. 'sweep' + 'experiments' keys (sweep-of-orchestrations format)") | |
| return False | |
| @experiment.command('rerun') | |
| @click.argument('execution_path') | |
| @click.option('--execution', help='Select specific execution by index (1=newest, 2=2nd newest) or execution ID') | |
| @click.option('--dry-run/--no-dry-run', default=False, help='Show what would be executed without running') | |
| def experiment_rerun(execution_path, execution, dry_run): | |
| """Rerun an experiment from a previous execution record. | |
| You can rerun an experiment using: | |
| 1. Execution directory path | |
| 2. Experiment ID (finds most recent execution) | |
| 3. Snapshot YAML file path | |
| Use --execution to select a specific execution when using experiment ID: | |
| - By index: --execution 2 (2nd most recent) | |
| - By execution ID: --execution 232056_train-and-compare-evaluators or --execution 232056 | |
| Examples: | |
| # Rerun from execution directory | |
| hae experiment rerun experiment-history/2025/01/22/232056_train-and-compare-evaluators | |
| # Rerun from snapshot YAML | |
| hae experiment rerun experiment-history/2025/01/22/232056_train-and-compare-evaluators/attachments/20250122232056_train-and-compare-evaluators.yaml | |
| # Rerun most recent execution of an experiment (by ID) | |
| hae experiment rerun train-and-compare-evaluators | |
| # Rerun 2nd most recent execution | |
| hae experiment rerun train-and-compare-evaluators --execution 2 | |
| # Rerun specific execution by timestamp | |
| hae experiment rerun train-and-compare-evaluators --execution 232056 | |
| """ | |
| exec_path = Path(execution_path) | |
| # Case 1: It's a snapshot YAML file | |
| if exec_path.exists() and exec_path.is_file() and exec_path.suffix in ['.yaml', '.yml']: | |
| snapshot_yaml = exec_path | |
| console.print(f"[cyan]Rerunning from snapshot: {snapshot_yaml}[/cyan]\n") | |
| success = _run_experiment_from_yaml(snapshot_yaml, dry_run=dry_run) | |
| sys.exit(0 if success else 1) | |
| # Case 2: It's an execution directory | |
| elif exec_path.exists() and exec_path.is_dir(): | |
| # Look for snapshot YAML in attachments/ | |
| attachments_dir = exec_path / "attachments" | |
| if attachments_dir.exists(): | |
| # Find the experiment snapshot (usually named with timestamp_experiment-name.yaml) | |
| snapshot_files = list(attachments_dir.glob("*_*.yaml")) + list(attachments_dir.glob("*_*.yml")) | |
| if snapshot_files: | |
| # Use the most recent snapshot | |
| snapshot_yaml = max(snapshot_files, key=lambda p: p.stat().st_mtime) | |
| console.print(f"[cyan]Rerunning from execution record: {exec_path.name}[/cyan]") | |
| console.print(f"[dim]Using snapshot: {snapshot_yaml.name}[/dim]\n") | |
| success = _run_experiment_from_yaml(snapshot_yaml, dry_run=dry_run) | |
| sys.exit(0 if success else 1) | |
| else: | |
| console.print(f"[red]No snapshot YAML found in {attachments_dir}[/red]") | |
| sys.exit(1) | |
| else: | |
| console.print(f"[red]No attachments directory found in {exec_path}[/red]") | |
| sys.exit(1) | |
| # Case 3: It's an experiment ID - find execution(s) | |
| else: | |
| # Try to find execution(s) with this ID | |
| executions_dir = Path("experiment-history") | |
| if not executions_dir.exists(): | |
| console.print(f"[red]No history directory found[/red]") | |
| sys.exit(1) | |
| # Search for execution directories matching this ID | |
| matching_execs = [] | |
| exp_id = str(execution_path) # Use the original string | |
| for exec_dir in executions_dir.rglob("*"): | |
| if exec_dir.is_dir() and "_" in exec_dir.name: | |
| # Extract experiment name from directory (after first _) | |
| parts = exec_dir.name.split("_", 1) | |
| if len(parts) == 2: | |
| exec_name = parts[1] | |
| # Normalize names for comparison | |
| exec_name_norm = exec_name.replace("-", "_").replace(" ", "_") | |
| exp_id_norm = exp_id.replace("-", "_").replace(" ", "_") | |
| if exec_name_norm == exp_id_norm or exec_name == exp_id: | |
| matching_execs.append(exec_dir) | |
| if not matching_execs: | |
| console.print(f"[red]No execution found for experiment: {exp_id}[/red]") | |
| console.print(f"[yellow]Try: hae experiment list (to see available experiments)[/yellow]") | |
| sys.exit(1) | |
| # Sort by modification time (most recent first) | |
| matching_execs.sort(key=lambda p: p.stat().st_mtime, reverse=True) | |
| # Select specific execution if --execution provided | |
| if execution: | |
| # Try to parse as integer index (1-based) | |
| try: | |
| exec_index = int(execution) - 1 # Convert to 0-based | |
| # Check if it's a valid index range | |
| if 0 <= exec_index < len(matching_execs): | |
| # Valid index - use it | |
| selected_exec = matching_execs[exec_index] | |
| console.print(f"[cyan]Found {len(matching_execs)} execution(s) for '{exp_id}'[/cyan]") | |
| console.print(f"[cyan]Rerunning execution #{execution}: {selected_exec.name}[/cyan]\n") | |
| else: | |
| # Integer parsed but out of range - treat as pattern instead | |
| raise ValueError("Index out of range, try pattern matching") | |
| except ValueError: | |
| # Not a valid integer index, treat as execution ID pattern | |
| # Match execution directories containing this pattern | |
| filtered_execs = [] | |
| for exec_dir in matching_execs: | |
| # Check if execution pattern is in the directory name | |
| if execution in exec_dir.name or execution.replace("-", "_") in exec_dir.name: | |
| filtered_execs.append(exec_dir) | |
| if not filtered_execs: | |
| console.print(f"[red]No execution matching '{execution}' found for experiment '{exp_id}'[/red]") | |
| console.print(f"[yellow]Available executions:[/yellow]") | |
| for i, e in enumerate(matching_execs[:5], 1): | |
| console.print(f" {i}. {e.name}") | |
| if len(matching_execs) > 5: | |
| console.print(f" ... and {len(matching_execs) - 5} more") | |
| sys.exit(1) | |
| selected_exec = filtered_execs[0] # Use most recent matching | |
| console.print(f"[cyan]Found {len(filtered_execs)} execution(s) matching '{execution}'[/cyan]") | |
| console.print(f"[cyan]Rerunning: {selected_exec.name}[/cyan]\n") | |
| else: | |
| # No --execution flag: use most recent | |
| selected_exec = matching_execs[0] | |
| console.print(f"[cyan]Found {len(matching_execs)} execution(s) for '{exp_id}'[/cyan]") | |
| console.print(f"[cyan]Rerunning most recent: {selected_exec.name}[/cyan]\n") | |
| most_recent = selected_exec # Rename for compatibility with code below | |
| # Process the execution directory | |
| attachments_dir = most_recent / "attachments" | |
| if attachments_dir.exists(): | |
| # Find the experiment snapshot (prefer experiment YAML, fallback to any YAML) | |
| snapshot_files = list(attachments_dir.glob("*_*.yaml")) + list(attachments_dir.glob("*_*.yml")) | |
| if snapshot_files: | |
| # Prefer experiment snapshot (usually named with experiment name) | |
| experiment_snapshots = [f for f in snapshot_files if 'experiment' in f.stem.lower() or exp_id.replace('-', '_') in f.stem.lower()] | |
| if experiment_snapshots: | |
| snapshot_yaml = max(experiment_snapshots, key=lambda p: p.stat().st_mtime) | |
| else: | |
| # Fallback to most recent snapshot | |
| snapshot_yaml = max(snapshot_files, key=lambda p: p.stat().st_mtime) | |
| console.print(f"[dim]Using snapshot: {snapshot_yaml.name}[/dim]\n") | |
| success = _run_experiment_from_yaml(snapshot_yaml, dry_run=dry_run) | |
| sys.exit(0 if success else 1) | |
| else: | |
| console.print(f"[red]No snapshot YAML found in {attachments_dir}[/red]") | |
| sys.exit(1) | |
| else: | |
| console.print(f"[red]No attachments directory found in {most_recent}[/red]") | |
| sys.exit(1) | |
| @experiment.command('history') | |
| @click.argument('experiment_id', required=False) | |
| @click.option('--limit', type=int, default=10, help='Maximum number of executions to show') | |
| @click.option('--status', help='Filter by status (completed, failed, running)') | |
| @click.option('--rebuild', is_flag=True, help='Rebuild database before querying') | |
| @click.option('--no-sync', is_flag=True, help='Skip auto-sync (use cached database)') | |
| def experiment_history(experiment_id, limit, status, rebuild, no_sync): | |
| """List execution history for experiments. | |
| Shows execution history with timestamps, status, and paths. | |
| Database is auto-synced incrementally unless --no-sync is specified. | |
| Examples: | |
| hae experiment history # Auto-syncs, show all recent | |
| hae experiment history synthetic-carpet-train # Show specific experiment | |
| hae experiment history --status completed --limit 20 # Show recent completed runs | |
| hae experiment history --rebuild # Full rebuild (slow) | |
| hae experiment history --no-sync # Skip sync (use cached data) | |
| """ | |
| from hae.db import ExperimentDB | |
| from hae.migration import sync_database_incremental | |
| import re | |
| # Initialize database | |
| db = ExperimentDB() | |
| # Sync strategy: rebuild, auto-sync, or skip | |
| if rebuild: | |
| # Full rebuild - clear and re-index everything | |
| console.print("[dim]Rebuilding experiment database...[/dim]") | |
| db.index_definitions() | |
| db.index_executions() | |
| exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0] | |
| console.print(f"[dim]Indexed {exec_count} executions[/dim]\n") | |
| elif not no_sync: | |
| # Smart auto-sync: only process changed/new files (fast!) | |
| new, updated, unchanged = sync_database_incremental(db) | |
| # Also sync executions (always incremental) | |
| db.index_executions() | |
| if new + updated > 0: | |
| console.print(f"[dim]↻ Synced: {new} new, {updated} updated definitions[/dim]") | |
| # else: no_sync is True, skip sync entirely | |
| # Query executions with filters | |
| executions = db.query_executions( | |
| definition_id=experiment_id, | |
| status=status, | |
| limit=limit | |
| ) | |
| db.close() | |
| if not executions: | |
| if experiment_id: | |
| console.print(f"[yellow]No executions found for experiment: {experiment_id}[/yellow]") | |
| else: | |
| console.print("[yellow]No executions found[/yellow]") | |
| return | |
| # Display table | |
| title = f"Executions for '{experiment_id}'" if experiment_id else "Recent Executions" | |
| table = Table(title=title) | |
| table.add_column("#", style="dim", width=4) | |
| table.add_column("Execution ID", style="cyan", no_wrap=True, max_width=28) | |
| table.add_column("Status", style="yellow", width=10) | |
| table.add_column("Timestamp", style="dim", width=16) | |
| table.add_column("Path", style="dim", max_width=35, overflow="ellipsis") | |
| for idx, exec in enumerate(executions, 1): | |
| # Status with icon | |
| status_icon = "✅" if exec.status == "completed" else "❌" if exec.status == "failed" else "⏳" | |
| status_str = f"{status_icon} {exec.status}" | |
| # Format timestamp (ISO format: 2025-11-15T23:21:12Z) | |
| timestamp_str = exec.timestamp | |
| match = re.search(r'(\d{4}-\d{2}-\d{2})T(\d{2}:\d{2}):\d{2}', exec.timestamp) | |
| if match: | |
| timestamp_str = f"{match.group(1)} {match.group(2)}" | |
| else: | |
| timestamp_str = exec.timestamp[:16] | |
| # Format execution ID (truncate if needed) | |
| exec_id = exec.id | |
| if len(exec_id) > 26: | |
| exec_id = exec_id[:23] + "..." | |
| # Format path (relative) | |
| from pathlib import Path | |
| try: | |
| rel_path = str(Path(exec.execution_path).relative_to(Path.cwd())) | |
| except ValueError: | |
| rel_path = exec.execution_path | |
| if len(rel_path) > 33: | |
| rel_path = "..." + rel_path[-30:] | |
| table.add_row( | |
| str(idx), | |
| exec_id, | |
| status_str, | |
| timestamp_str, | |
| rel_path | |
| ) | |
| console.print(table) | |
| # Summary | |
| summary_parts = [f"Total: {len(executions)} execution(s)"] | |
| if experiment_id: | |
| summary_parts.append(f"experiment: {experiment_id}") | |
| if status: | |
| summary_parts.append(f"status: {status}") | |
| console.print(f"\n[dim]{' | '.join(summary_parts)}[/dim]") | |
| console.print(f"[dim]Rerun: hae experiment rerun <execution_id> or hae experiment rerun <path>[/dim]") | |
| @experiment.command('show') | |
| @click.argument('experiment_id') | |
| @click.option('--execution', help='Select specific execution by index (1=newest) or execution ID') | |
| def experiment_show(experiment_id, execution): | |
| """Show detailed information about a specific execution. | |
| Examples: | |
| hae experiment show train-and-compare-evaluators # Show newest execution | |
| hae experiment show train-and-compare-evaluators --execution 2 # Show 2nd newest | |
| hae experiment show train-and-compare-evaluators --execution 232056 | |
| """ | |
| import re | |
| executions_dir = Path("experiment-history") | |
| if not executions_dir.exists(): | |
| console.print("[yellow]No history directory found[/yellow]") | |
| return | |
| # Find all executions matching this experiment ID | |
| matching_execs = [] | |
| exp_id_norm = experiment_id.replace("-", "_").replace(" ", "_") | |
| for exec_dir in executions_dir.rglob("*"): | |
| if exec_dir.is_dir() and "_" in exec_dir.name: | |
| parts = exec_dir.name.split("_", 1) | |
| if len(parts) == 2: | |
| exec_name = parts[1] | |
| exec_name_norm = exec_name.replace("-", "_").replace(" ", "_") | |
| if exec_name_norm == exp_id_norm or exec_name == experiment_id: | |
| matching_execs.append(exec_dir) | |
| if not matching_execs: | |
| console.print(f"[yellow]No executions found for experiment: {experiment_id}[/yellow]") | |
| return | |
| # Sort by modification time (newest first) | |
| matching_execs.sort(key=lambda p: p.stat().st_mtime, reverse=True) | |
| # Select specific execution if --execution provided, otherwise use newest | |
| if execution: | |
| # Try to parse as integer index (1-based) | |
| try: | |
| exec_index = int(execution) - 1 # Convert to 0-based | |
| if 0 <= exec_index < len(matching_execs): | |
| selected_exec = matching_execs[exec_index] | |
| else: | |
| raise ValueError("Index out of range, try pattern matching") | |
| except ValueError: | |
| # Not a valid integer index, treat as execution ID pattern | |
| filtered_execs = [] | |
| for exec_dir in matching_execs: | |
| if execution in exec_dir.name or execution.replace("-", "_") in exec_dir.name: | |
| filtered_execs.append(exec_dir) | |
| if not filtered_execs: | |
| console.print(f"[red]No execution found matching: {execution}[/red]") | |
| console.print(f"[dim]Found {len(matching_execs)} execution(s) for '{experiment_id}'[/dim]") | |
| console.print(f"[dim]Try: hae experiment history {experiment_id}[/dim]") | |
| sys.exit(1) | |
| if len(filtered_execs) > 1: | |
| console.print(f"[yellow]Multiple executions match '{execution}':[/yellow]") | |
| for i, e in enumerate(filtered_execs, 1): | |
| console.print(f" {i}. {e.name}") | |
| console.print(f"[dim]Use --execution <number> to select a specific one[/dim]") | |
| sys.exit(1) | |
| selected_exec = filtered_execs[0] | |
| else: | |
| # No execution specified, use newest | |
| selected_exec = matching_execs[0] | |
| # Read experiment.org file | |
| org_file = selected_exec / "experiment.org" | |
| if not org_file.exists(): | |
| console.print(f"[red]experiment.org not found in {selected_exec}[/red]") | |
| return | |
| with open(org_file, encoding='utf-8') as f: | |
| org_content = f.read() | |
| # Parse metadata from org file | |
| def extract_property(name): | |
| match = re.search(rf':{name}: (.+)', org_content) | |
| return match.group(1) if match else "N/A" | |
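| # extract_property assumes org-mode property-drawer lines, one key per line, e.g. | |
| #   :STATUS: completed | |
| #   :GIT_BRANCH: main | |
| # (illustrative values; the keys actually read are listed below) | |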
| # Extract basic info | |
| exp_name = extract_property("EXPERIMENT_NAME") | |
| status = extract_property("STATUS") | |
| started = extract_property("STARTED") | |
| completed = extract_property("COMPLETED") | |
| duration = extract_property("DURATION") | |
| # Extract system info | |
| git_commit = extract_property("GIT_COMMIT") | |
| git_branch = extract_property("GIT_BRANCH") | |
| git_dirty = extract_property("GIT_DIRTY") | |
| python_version = extract_property("PYTHON_VERSION") | |
| cuda_version = extract_property("CUDA_VERSION") | |
| pytorch_version = extract_property("PYTORCH_VERSION") | |
| # Display information | |
| console.print(f"\n[bold cyan]Execution: {selected_exec.name}[/bold cyan]") | |
| console.print(f"[bold]Experiment:[/bold] {exp_name}") | |
| console.print(f"[bold]Status:[/bold] {status}") | |
| console.print(f"[bold]Started:[/bold] {started}") | |
| console.print(f"[bold]Completed:[/bold] {completed}") | |
| console.print(f"[bold]Duration:[/bold] {duration}") | |
| console.print(f"\n[bold]System Information:[/bold]") | |
| console.print(f" Git Commit: {git_commit}") | |
| console.print(f" Git Branch: {git_branch}") | |
| console.print(f" Git Dirty: {git_dirty}") | |
| console.print(f" Python: {python_version}") | |
| console.print(f" PyTorch: {pytorch_version}") | |
| console.print(f" CUDA: {cuda_version}") | |
| console.print(f"\n[bold]Path:[/bold] {selected_exec}") | |
| console.print(f"[dim]View full details: cat {org_file}[/dim]") | |
| @experiment.command('status') | |
| @click.argument('experiment_id') | |
| def experiment_status(experiment_id): | |
| """Show detailed status for an experiment. | |
| Shows all execution history, including running, completed, and failed experiments. | |
| Example: | |
| hae experiment status train-and-compare-evaluators | |
| """ | |
| from .commands.experiment_zettelkasten import is_experiment_running | |
| import re | |
| from rich.table import Table | |
| # Find executions for this experiment | |
| executions_dir = Path("experiment-history") | |
| if not executions_dir.exists(): | |
| console.print(f"[yellow]No experiment history found[/yellow]") | |
| sys.exit(1) | |
| # Find matching executions | |
| matching_execs = [] | |
| for exec_dir in executions_dir.rglob("*"): | |
| if exec_dir.is_dir() and "_" in exec_dir.name: | |
| parts = exec_dir.name.split("_", 1) | |
| if len(parts) == 2: | |
| exec_name = parts[1] | |
| exec_name_norm = exec_name.replace("-", "_").replace(" ", "_") | |
| exp_id_norm = experiment_id.replace("-", "_").replace(" ", "_") | |
| if exec_name_norm == exp_id_norm or exec_name == experiment_id: | |
| matching_execs.append(exec_dir) | |
| if not matching_execs: | |
| console.print(f"[yellow]No executions found for experiment: {experiment_id}[/yellow]") | |
| sys.exit(1) | |
| # Sort by modification time (most recent first) | |
| matching_execs.sort(key=lambda p: p.stat().st_mtime, reverse=True) | |
| # Count statuses | |
| running_count = 0 | |
| succeeded_count = 0 | |
| failed_count = 0 | |
| unknown_count = 0 | |
| # Collect execution details | |
| exec_details = [] | |
| for exec_dir in matching_execs: | |
| try: | |
| # Check if running | |
| if is_experiment_running(exec_dir): | |
| status = "running" | |
| running_count += 1 | |
| started = "N/A" | |
| completed = "N/A" | |
| else: | |
| # Read status from experiment.org | |
| org_file = exec_dir / "experiment.org" | |
| if org_file.exists(): | |
| with open(org_file, encoding='utf-8') as f: | |
| content = f.read() | |
| status_match = re.search(r':STATUS: (\w+)', content) | |
| if status_match: | |
| status = status_match.group(1) | |
| if status == "completed": | |
| succeeded_count += 1 | |
| elif status == "failed": | |
| failed_count += 1 | |
| else: | |
| unknown_count += 1 | |
| else: | |
| status = "unknown" | |
| unknown_count += 1 | |
| # Extract timing info | |
| started_match = re.search(r':STARTED: \[([^\]]+)\]', content) | |
| completed_match = re.search(r':COMPLETED: \[([^\]]+)\]', content) | |
| started = started_match.group(1) if started_match else "N/A" | |
| completed = completed_match.group(1) if completed_match else "N/A" | |
| else: | |
| status = "unknown" | |
| unknown_count += 1 | |
| started = "N/A" | |
| completed = "N/A" | |
| # Get relative path | |
| try: | |
| rel_path = str(exec_dir.relative_to(Path.cwd())) | |
| except ValueError: | |
| rel_path = str(exec_dir) | |
| exec_details.append({ | |
| 'path': rel_path, | |
| 'status': status, | |
| 'started': started, | |
| 'completed': completed | |
| }) | |
| except Exception as e: | |
| console.print(f"[yellow]Warning: Error reading {exec_dir}: {e}[/yellow]") | |
| # Display summary | |
| total = len(matching_execs) | |
| console.print(f"\n[bold cyan]Experiment: {experiment_id}[/bold cyan]") | |
| console.print(f"\n[bold]Summary:[/bold]") | |
| console.print(f" Total executions: {total}") | |
| console.print(f" Running: [yellow]{running_count}[/yellow]") | |
| console.print(f" Succeeded: [green]{succeeded_count}[/green]") | |
| console.print(f" Failed: [red]{failed_count}[/red]") | |
| if unknown_count > 0: | |
| console.print(f" Unknown: [dim]{unknown_count}[/dim]") | |
| # Display table of recent executions | |
| console.print(f"\n[bold]Recent Executions:[/bold]") | |
| table = Table(show_header=True, header_style="bold", box=None, padding=(0, 1)) | |
| table.add_column("Status", style="yellow", width=12) | |
| table.add_column("Started", style="cyan", width=22) | |
| table.add_column("Completed", style="green", width=22) | |
| table.add_column("Path", style="dim") | |
| # Show up to 10 most recent | |
| for exec_info in exec_details[:10]: | |
| status = exec_info['status'] | |
| # Color code status | |
| if status == "running": | |
| status_str = "[yellow]⚙ running[/yellow]" | |
| elif status == "completed": | |
| status_str = "[green]✓ completed[/green]" | |
| elif status == "failed": | |
| status_str = "[red]✗ failed[/red]" | |
| else: | |
| status_str = "[dim]? unknown[/dim]" | |
| table.add_row( | |
| status_str, | |
| exec_info['started'], | |
| exec_info['completed'], | |
| exec_info['path'] | |
| ) | |
| console.print(table) | |
| if len(exec_details) > 10: | |
| console.print(f"\n[dim]... and {len(exec_details) - 10} more executions[/dim]") | |
| console.print(f"\n[dim]View execution: cat <path>/experiment.org[/dim]") | |
| @main.group() | |
| def dataset(): | |
| """Dataset management commands. | |
| List and inspect dataset statistics, splits, and configurations. | |
| """ | |
| pass | |
| @dataset.command('list') | |
| @click.option('--dataset', type=str, help='Filter by dataset type (mvtec, visa, etc.)') | |
| @click.option('--object-category', type=str, help='Filter by object category') | |
| @click.option('--data-dir', type=str, help='Data directory path (default: ./datasets/raw/MVTec/mvtec_ad)') | |
| @click.option('--csv-split-file', type=str, help='CSV split file path') | |
| def datasets_list(dataset, object_category, data_dir, csv_split_file): | |
| """List available datasets with statistics and split information. | |
| Shows dataset statistics including: | |
| - Dataset type and object categories | |
| - Split counts (train/val/test) | |
| - CSV split file information | |
| - Data directory paths | |
| Examples: | |
| hae dataset list # List all datasets | |
| hae dataset list --dataset mvtec # List MVTec datasets only | |
| hae dataset list --object-category bottle # Show bottle category stats | |
| hae dataset list --data-dir ./my-data # Use custom data directory | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Import dataset utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.dataset_configs import BASE_DATASET_REGISTRY, get_base_dataset_config | |
| from dioodmi.data.MVTECDataLoader import MVTECDataset | |
| from dioodmi.data.anomaly_core import DatasetConfig, ProcessingParams | |
| except ImportError as e: | |
| console.print(f"[red]Error importing dataset modules: {e}[/red]") | |
| console.print("[yellow]Make sure dioodmi-py is properly installed[/yellow]") | |
| sys.exit(1) | |
| # Determine which datasets to show | |
| if dataset: | |
| if dataset not in BASE_DATASET_REGISTRY: | |
| console.print(f"[red]Unknown dataset: {dataset}[/red]") | |
| console.print(f"[yellow]Available datasets: {', '.join(BASE_DATASET_REGISTRY.keys())}[/yellow]") | |
| sys.exit(1) | |
| datasets_to_check = [dataset] | |
| else: | |
| datasets_to_check = list(BASE_DATASET_REGISTRY.keys()) | |
| # Default data directory | |
| if not data_dir: | |
| data_dir = './datasets/raw/MVTec/mvtec_ad' | |
| all_dataset_stats = [] | |
| for dataset_name in sorted(datasets_to_check): | |
| base_config = BASE_DATASET_REGISTRY[dataset_name] | |
| # Determine object categories to check | |
| if object_category: | |
| if object_category not in base_config.object_classes: | |
| continue # Skip this dataset if category not found | |
| categories_to_check = [object_category] | |
| else: | |
| categories_to_check = list(base_config.object_classes.keys()) | |
| for category in categories_to_check: | |
| # Try to load dataset for each split to get counts | |
| stats = { | |
| 'dataset': dataset_name, | |
| 'category': category, | |
| 'train_count': None, | |
| 'val_count': None, | |
| 'test_count': None, | |
| 'data_dir': data_dir, | |
| 'csv_split_file': csv_split_file or base_config.csv_split_file, | |
| 'error': None | |
| } | |
| # Try to get counts for each split | |
| for split in ['train', 'val', 'test']: | |
| try: | |
| # Create dataset config | |
| config = DatasetConfig( | |
| name=base_config.name, | |
| object_classes=base_config.object_classes, | |
| csv_columns=base_config.csv_columns, | |
| default_image_size=base_config.default_image_size, | |
| default_crop_size=base_config.default_crop_size, | |
| csv_split_file=csv_split_file if csv_split_file else base_config.csv_split_file, | |
| strategies=base_config.strategies, | |
| temporal_strategies=base_config.temporal_strategies, | |
| mode=split, | |
| object_class=category, | |
| rootdir=data_dir, | |
| anomaly_class='good' if split == 'train' else 'all' | |
| ) | |
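| # The train split is restricted to defect-free ('good') samples, while | |
| # val/test include every anomaly class (see anomaly_class above). | |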
| # Create processing params | |
| params = ProcessingParams( | |
| image_size=base_config.default_image_size, | |
| crop_size=base_config.default_crop_size | |
| ) | |
| # Try to load dataset (without actually loading images) | |
| from dioodmi.data.anomaly_core import parse_csv_data | |
| data_df = parse_csv_data(config, params) | |
| count = len(data_df) | |
| stats[f'{split}_count'] = count | |
| except Exception as e: | |
| # Dataset might not exist or split might not be available | |
| stats['error'] = str(e) | |
| continue | |
| all_dataset_stats.append(stats) | |
| if not all_dataset_stats: | |
| console.print("[yellow]No datasets found matching criteria[/yellow]") | |
| return | |
| # Display table | |
| table = Table(title="Dataset Statistics", show_header=True, header_style="bold", box=None, padding=(0, 1)) | |
| table.add_column("Dataset", style="cyan", no_wrap=True, width=15) | |
| table.add_column("Category", style="green", no_wrap=True, width=15) | |
| table.add_column("Train", style="yellow", justify="right", width=8) | |
| table.add_column("Val", style="yellow", justify="right", width=8) | |
| table.add_column("Test", style="yellow", justify="right", width=8) | |
| table.add_column("CSV Split", style="dim", width=20) | |
| for stats in all_dataset_stats: | |
| train_str = str(stats['train_count']) if stats['train_count'] is not None else "-" | |
| val_str = str(stats['val_count']) if stats['val_count'] is not None else "-" | |
| test_str = str(stats['test_count']) if stats['test_count'] is not None else "-" | |
| csv_file = stats['csv_split_file'] | |
| if len(csv_file) > 18: | |
| csv_file = csv_file[:15] + "..." | |
| table.add_row( | |
| stats['dataset'], | |
| stats['category'], | |
| train_str, | |
| val_str, | |
| test_str, | |
| csv_file | |
| ) | |
| console.print(table) | |
| # Show summary | |
| total_datasets = len(set(s['dataset'] for s in all_dataset_stats)) | |
| total_categories = len(all_dataset_stats) | |
| console.print(f"\n[dim]Total: {total_categories} category/dataset combinations ({total_datasets} unique datasets)[/dim]") | |
| if data_dir: | |
| console.print(f"[dim]Data directory: {data_dir}[/dim]") | |
| @dataset.group() | |
| def split(): | |
| """Dataset split management commands. | |
| Create and manage dataset split CSV files. | |
| """ | |
| pass | |
| @split.command('create') | |
| @click.option('--dataset', type=str, help='Dataset type (mvtec, visa, etc.)') | |
| @click.option('--object-category', type=str, help='Object category to process') | |
| @click.option('--data-dir', type=str, required=True, help='Data directory path') | |
| @click.option('--output-csv', type=str, required=True, | |
| help='Output CSV file path (bare filenames default to definitions/splits/)') | |
| @click.option('--strategy', type=click.Choice(['static', 'pool']), default='static', | |
| help='Split strategy: static (preserve directory structure) or pool (all samples as pool)') | |
| @click.option('--include-index', is_flag=True, help='Include index column in CSV') | |
| @click.option('--no-masks', is_flag=True, help='Do not search for mask files') | |
| def split_create(dataset, object_category, data_dir, output_csv, strategy, include_index, no_masks): | |
| """Create split CSV file from dataset directory structure. | |
| Discovers files from directory structure and generates CSV split file. | |
| Supports both static splits (preserve directory structure) and pool mode | |
| (all samples marked as 'pool' for dynamic splitting). | |
| Examples: | |
| # Create pool CSV (all samples as 'pool') | |
| hae dataset split create \\ | |
| --data-dir ./datasets/raw/MVTec/mvtec_ad \\ | |
| --object-category bottle \\ | |
| --output-csv splits/mvtec-bottle-pool.csv \\ | |
| --strategy pool | |
| # Create static split CSV (preserve directory structure) | |
| hae dataset split create \\ | |
| --data-dir ./datasets/raw/MVTec/mvtec_ad \\ | |
| --object-category bottle \\ | |
| --output-csv splits/mvtec-bottle-static.csv \\ | |
| --strategy static | |
| """ | |
| from pathlib import Path | |
| import sys | |
| import csv | |
| # Import file discovery utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.file_discovery import ( | |
| discover_files_from_directory, | |
| detect_structure | |
| ) | |
| except ImportError as e: | |
| console.print(f"[red]Error importing file discovery modules: {e}[/red]") | |
| console.print("[yellow]Make sure dioodmi-py is properly installed[/yellow]") | |
| sys.exit(1) | |
| # Validate data directory | |
| data_dir_path = Path(data_dir) | |
| if not data_dir_path.exists(): | |
| console.print(f"[red]Data directory not found: {data_dir}[/red]") | |
| sys.exit(1) | |
| # Determine object categories to process | |
| if object_category: | |
| object_categories = [object_category] | |
| else: | |
| # Auto-detect all object categories | |
| object_categories = [] | |
| for item in sorted(data_dir_path.iterdir()): | |
| if item.is_dir() and not item.name.startswith('.'): | |
| detected_structure = detect_structure(item) | |
| if detected_structure != "unknown": | |
| object_categories.append(item.name) | |
| if not object_categories: | |
| console.print(f"[red]No valid object categories found in {data_dir}[/red]") | |
| console.print("[yellow]Please specify --object-category explicitly[/yellow]") | |
| sys.exit(1) | |
| # Process each object category | |
| all_rows = [] | |
| current_idx = 0 | |
| console.print(f"[bold]Creating split CSV file[/bold]") | |
| console.print(f" Data directory: {data_dir}") | |
| console.print(f" Output CSV: {output_csv}") | |
| console.print(f" Strategy: {strategy}") | |
| console.print(f" Object categories: {', '.join(object_categories)}\n") | |
| for obj_cat in object_categories: | |
| try: | |
| df = discover_files_from_directory( | |
| data_dir=data_dir, | |
| object_category=obj_cat, | |
| structure=None, # Auto-detect | |
| include_masks=not no_masks, | |
| split_map=None, | |
| pool_mode=(strategy == 'pool') | |
| ) | |
| # Convert DataFrame to rows | |
| for _, row in df.iterrows(): | |
| row_dict = row.to_dict() | |
| if include_index: | |
| row_dict[''] = current_idx | |
| current_idx += 1 | |
| all_rows.append(row_dict) | |
| # Show summary for this category | |
| train_count = len(df[df['split'] == 'train']) if 'train' in df['split'].values else 0 | |
| test_count = len(df[df['split'] == 'test']) if 'test' in df['split'].values else 0 | |
| val_count = len(df[df['split'] == 'val']) if 'val' in df['split'].values else 0 | |
| pool_count = len(df[df['split'] == 'pool']) if 'pool' in df['split'].values else 0 | |
| console.print(f" [{obj_cat}] Found {len(df)} images") | |
| if strategy == 'pool': | |
| console.print(f" All marked as 'pool' for dynamic splitting") | |
| else: | |
| if train_count > 0: | |
| console.print(f" Train: {train_count}") | |
| if val_count > 0: | |
| console.print(f" Val: {val_count}") | |
| if test_count > 0: | |
| console.print(f" Test: {test_count}") | |
| except Exception as e: | |
| console.print(f"[red]Error processing {obj_cat}: {e}[/red]") | |
| continue | |
| if not all_rows: | |
| console.print("[red]No data found to write to CSV[/red]") | |
| sys.exit(1) | |
| # Write CSV file | |
| output_path = Path(output_csv) | |
| # If given a bare filename (relative, no directory part), default to definitions/splits/ | |
| if not output_path.is_absolute() and output_path.parent == Path('.'): | |
| output_path = Path('definitions/splits') / output_path.name | |
| output_path.parent.mkdir(parents=True, exist_ok=True) | |
| # Determine fieldnames | |
| fieldnames = [] | |
| if include_index: | |
| fieldnames.append("") | |
| fieldnames.extend(["object", "split", "label", "image", "mask", "category"]) | |
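| # Output schema: object, split, label, image, mask, category, as produced by | |
| # discover_files_from_directory; with --include-index, an extra unnamed first | |
| # column carries the running integer index assigned above. | |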
| with open(output_path, "w", newline="", encoding="utf-8") as f: | |
| writer = csv.DictWriter(f, fieldnames=fieldnames) | |
| writer.writeheader() | |
| writer.writerows(all_rows) | |
| console.print(f"\n[green]✓ Created split CSV: {output_path}[/green]") | |
| console.print(f" Total entries: {len(all_rows)}") | |
| @split.command('add') | |
| @click.option('--csv', 'csv_file', type=str, required=True, help='Existing CSV file to update') | |
| @click.option('--from-dir', type=str, help='Discover files from directory') | |
| @click.option('--from-csv', type=str, help='Add samples from another CSV file') | |
| @click.option('--object-category', type=str, help='Only add samples for this object category') | |
| @click.option('--split', type=click.Choice(['train', 'val', 'test', 'pool']), help='Override split assignment (for directory input)') | |
| @click.option('--pool-mode', is_flag=True, help='Mark all new samples as pool (for directory input)') | |
| @click.option('--duplicate-strategy', type=click.Choice(['skip', 'error']), default='error', | |
| help='How to handle duplicates: skip (keep existing) or error (fail)') | |
| @click.option('--output-csv', type=str, help='Write to different file (default: update in-place)') | |
| @click.option('--backup', is_flag=True, help='Create backup before updating') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be added without modifying file') | |
| @click.option('--structure', type=click.Choice(['ok_ng', 'category']), help='Force structure type (default: auto-detect)') | |
| @click.option('--no-masks', is_flag=True, help='Do not search for mask files') | |
| @click.option('--include-index', is_flag=True, help='Include index column in output') | |
| def split_add(csv_file, from_dir, from_csv, object_category, split, pool_mode, duplicate_strategy, | |
| output_csv, backup, dry_run, structure, no_masks, include_index): | |
| """Add new samples to an existing split CSV file. | |
| Supports adding samples from directory structure or from another CSV file. | |
| Only adds NEW samples - use 'update' command to modify existing entries. | |
| Examples: | |
| # Add new samples from directory | |
| hae dataset split add \\ | |
| --csv splits/mvtec-bottle-pool.csv \\ | |
| --from-dir ./datasets/raw/MVTec/mvtec_ad/bottle \\ | |
| --object-category bottle \\ | |
| --split pool \\ | |
| --duplicate-strategy skip | |
| # Merge from another CSV | |
| hae dataset split add \\ | |
| --csv splits/mvtec-bottle-full.csv \\ | |
| --from-csv splits/mvtec-bottle-val-only.csv \\ | |
| --duplicate-strategy skip | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Import utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.split_utils import ( | |
| load_split_csv, save_split_csv, resolve_split_csv_path, | |
| find_duplicates, create_backup | |
| ) | |
| from dioodmi.data.file_discovery import discover_files_from_directory | |
| except ImportError as e: | |
| console.print(f"[red]Error importing modules: {e}[/red]") | |
| sys.exit(1) | |
| # Validate input source | |
| if not from_dir and not from_csv: | |
| console.print("[red]Error: Must specify either --from-dir or --from-csv[/red]") | |
| sys.exit(1) | |
| if from_dir and from_csv: | |
| console.print("[red]Error: Cannot specify both --from-dir and --from-csv[/red]") | |
| sys.exit(1) | |
| # Load existing CSV | |
| try: | |
| existing_df = load_split_csv(csv_file) | |
| console.print(f"[dim]Loaded existing CSV: {len(existing_df)} samples[/dim]") | |
| except Exception as e: | |
| console.print(f"[red]Error loading CSV: {e}[/red]") | |
| sys.exit(1) | |
| # Discover or load new samples | |
| if from_dir: | |
| # Discover from directory | |
| if not object_category: | |
| console.print("[red]Error: --object-category required when using --from-dir[/red]") | |
| sys.exit(1) | |
| try: | |
| new_df = discover_files_from_directory( | |
| data_dir=from_dir, | |
| object_category=object_category, | |
| structure=structure, | |
| include_masks=not no_masks, | |
| split_map=None, | |
| pool_mode=pool_mode or (split == 'pool') | |
| ) | |
| # Override split if specified | |
| if split and not pool_mode: | |
| new_df['split'] = split | |
| console.print(f"[dim]Discovered {len(new_df)} new samples from directory[/dim]") | |
| except Exception as e: | |
| console.print(f"[red]Error discovering files: {e}[/red]") | |
| sys.exit(1) | |
| else: # from_csv | |
| # Load from another CSV | |
| try: | |
| new_df = load_split_csv(from_csv) | |
| console.print(f"[dim]Loaded {len(new_df)} samples from source CSV[/dim]") | |
| # Override split if specified | |
| if split: | |
| new_df['split'] = split | |
| # Filter by object category if specified | |
| if object_category: | |
| new_df = new_df[new_df['object'] == object_category].copy() | |
| console.print(f"[dim]Filtered to {len(new_df)} samples for object '{object_category}'[/dim]") | |
| except Exception as e: | |
| console.print(f"[red]Error loading source CSV: {e}[/red]") | |
| sys.exit(1) | |
| if len(new_df) == 0: | |
| console.print("[yellow]No new samples to add[/yellow]") | |
| return | |
| # Check for duplicates | |
| existing_keys = set(zip(existing_df['object'], existing_df['image'])) | |
| new_keys = set(zip(new_df['object'], new_df['image'])) | |
| duplicates = existing_keys & new_keys | |
| if duplicates: | |
| if duplicate_strategy == 'error': | |
| console.print(f"[red]Error: Found {len(duplicates)} duplicate entries[/red]") | |
| console.print("[yellow]Use --duplicate-strategy skip to ignore duplicates[/yellow]") | |
| sys.exit(1) | |
| else: # skip | |
| console.print(f"[yellow]Found {len(duplicates)} duplicate entries, skipping[/yellow]") | |
| # Filter out duplicates | |
| new_df = new_df[~new_df.apply(lambda row: (row['object'], row['image']) in duplicates, axis=1)].copy() | |
| console.print(f"[dim]Adding {len(new_df)} new samples (skipped {len(duplicates)} duplicates)[/dim]") | |
| if len(new_df) == 0: | |
| console.print("[yellow]No new samples to add after filtering duplicates[/yellow]") | |
| return | |
| # Combine DataFrames | |
| combined_df = pd.concat([existing_df, new_df], ignore_index=True) | |
| # Statistics | |
| stats = { | |
| 'existing': len(existing_df), | |
| 'new': len(new_df), | |
| 'duplicates_skipped': len(duplicates), | |
| 'total': len(combined_df) | |
| } | |
| if dry_run: | |
| console.print("\n[bold]Dry run - no changes made[/bold]") | |
| console.print(f" Existing samples: {stats['existing']}") | |
| console.print(f" New samples to add: {stats['new']}") | |
| if stats['duplicates_skipped'] > 0: | |
| console.print(f" Duplicates skipped: {stats['duplicates_skipped']}") | |
| console.print(f" Total after add: {stats['total']}") | |
| return | |
| # Create backup if requested | |
| if backup: | |
| backup_path = create_backup(csv_file) | |
| console.print(f"[dim]Created backup: {backup_path}[/dim]") | |
| # Determine output path | |
| output_path = Path(output_csv) if output_csv else resolve_split_csv_path(csv_file) | |
| # Save | |
| try: | |
| save_split_csv(combined_df, str(output_path), include_index=include_index) | |
| console.print(f"\n[green]✓ Added {stats['new']} samples to CSV[/green]") | |
| console.print(f" Total samples: {stats['total']}") | |
| if stats['duplicates_skipped'] > 0: | |
| console.print(f" Duplicates skipped: {stats['duplicates_skipped']}") | |
| console.print(f" Saved to: {output_path}") | |
| except Exception as e: | |
| console.print(f"[red]Error saving CSV: {e}[/red]") | |
| sys.exit(1) | |
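| # Illustrative sketch of the duplicate check used by 'split add': samples are | |
| # keyed by (object, image), and incoming rows whose key already exists are | |
| # dropped before concatenation. Standalone pandas version; the real command | |
| # delegates discovery and I/O to the dioodmi helpers imported above. | |
| def _sketch_skip_duplicates(existing_df, new_df): | |
|     import pandas as pd | |
|     if new_df.empty: | |
|         return existing_df.copy() | |
|     existing_keys = set(zip(existing_df["object"], existing_df["image"])) | |
|     keep = ~new_df.apply(lambda row: (row["object"], row["image"]) in existing_keys, axis=1) | |
|     return pd.concat([existing_df, new_df[keep]], ignore_index=True) | |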
| @split.command('list') | |
| @click.option('--splits-dir', type=str, default='./definitions/splits/', | |
| help='Directory to search for split CSV files (default: definitions/splits/)') | |
| @click.option('--all', 'list_all', is_flag=True, help='List from both definitions/splits/ and fixtures/splits/') | |
| @click.option('--dataset', type=str, help='Filter by dataset type (e.g., mvtec, sdp, visa)') | |
| @click.option('--object-category', type=str, help='Filter by object category') | |
| @click.option('--split-type', type=click.Choice(['train', 'val', 'test', 'pool']), help='Filter by split type') | |
| @click.option('--pattern', type=str, help='Filter by filename pattern (e.g., *pool*.csv)') | |
| @click.option('--format', type=click.Choice(['table', 'json']), default='table', help='Output format') | |
| @click.option('--detailed', is_flag=True, help='Show detailed statistics') | |
| @click.option('--validate', is_flag=True, help='Validate CSV files') | |
| @click.option('--sort-by', type=click.Choice(['name', 'size', 'date', 'samples']), default='name', help='Sort results') | |
| def split_list(splits_dir, list_all, dataset, object_category, split_type, pattern, format, detailed, validate, sort_by): | |
| """List all available split CSV files with statistics. | |
| Examples: | |
| # List all split files | |
| hae dataset split list | |
| # List MVTec splits only | |
| hae dataset split list --dataset mvtec | |
| # Show detailed statistics | |
| hae dataset split list --detailed | |
| # Validate all splits | |
| hae dataset split list --validate | |
| """ | |
| from pathlib import Path | |
| import sys | |
| from datetime import datetime | |
| # Import utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.split_utils import ( | |
| load_split_csv, compute_statistics, compute_per_object_stats, | |
| validate_structure, resolve_split_csv_path | |
| ) | |
| except ImportError as e: | |
| console.print(f"[red]Error importing modules: {e}[/red]") | |
| sys.exit(1) | |
| # Determine directories to search | |
| if list_all: | |
| # Search both locations | |
| search_dirs = [ | |
| Path('definitions/splits'), | |
| Path('fixtures/splits') | |
| ] | |
| else: | |
| # Search specified directory | |
| search_dirs = [Path(splits_dir)] | |
| # Find all CSV files from all search directories | |
| csv_files = [] | |
| for search_dir in search_dirs: | |
| if not search_dir.exists(): | |
| if not list_all: # Only warn if not using --all | |
| console.print(f"[yellow]Splits directory not found: {search_dir}[/yellow]") | |
| continue | |
| if pattern: | |
| csv_files.extend(search_dir.glob(pattern)) | |
| else: | |
| csv_files.extend(search_dir.glob("*.csv")) | |
| # Remove duplicates (in case both directories have same files) | |
| csv_files = list(set(csv_files)) | |
| if not csv_files: | |
| if list_all: | |
| console.print(f"[yellow]No CSV files found in definitions/splits/ or fixtures/splits/[/yellow]") | |
| else: | |
| console.print(f"[yellow]No CSV files found in {splits_dir}[/yellow]") | |
| return | |
| # Analyze each file | |
| results = [] | |
| for csv_file in sorted(csv_files): | |
| try: | |
| df = load_split_csv(str(csv_file)) | |
| stats = compute_statistics(df) | |
| # Apply filters | |
| if dataset and dataset.lower() not in csv_file.stem.lower(): | |
| continue | |
| if object_category and object_category not in stats['objects']: | |
| continue | |
| if split_type and split_type not in stats['splits']: | |
| continue | |
| # Get file metadata | |
| file_stat = csv_file.stat() | |
| result = { | |
| 'filename': csv_file.name, | |
| 'path': str(csv_file), | |
| 'file_size': file_stat.st_size, | |
| 'modified': datetime.fromtimestamp(file_stat.st_mtime), | |
| 'total_samples': stats['total_samples'], | |
| 'objects': stats['objects'], | |
| 'splits': stats['splits'], | |
| 'labels': stats['labels'], | |
| 'categories': stats['categories'], | |
| 'mask_coverage': stats['mask_coverage'], | |
| } | |
| if detailed: | |
| result['per_object'] = compute_per_object_stats(df) | |
| if validate: | |
| validation = validate_structure(df) | |
| result['validation'] = validation | |
| results.append(result) | |
| except Exception as e: | |
| console.print(f"[yellow]Error processing {csv_file.name}: {e}[/yellow]") | |
| continue | |
| if not results: | |
| console.print("[yellow]No matching split files found[/yellow]") | |
| return | |
| # Sort results | |
| if sort_by == 'name': | |
| results.sort(key=lambda x: x['filename']) | |
| elif sort_by == 'size': | |
| results.sort(key=lambda x: x['file_size'], reverse=True) | |
| elif sort_by == 'date': | |
| results.sort(key=lambda x: x['modified'], reverse=True) | |
| elif sort_by == 'samples': | |
| results.sort(key=lambda x: x['total_samples'], reverse=True) | |
| # Format output | |
| if format == 'json': | |
| import json | |
| console.print(json.dumps(results, indent=2, default=str)) | |
| else: # table | |
| from rich.table import Table | |
| table = Table(title="Split CSV Files", show_header=True, header_style="bold magenta") | |
| table.add_column("Filename", style="cyan") | |
| table.add_column("Samples", justify="right") | |
| table.add_column("Objects", justify="right") | |
| table.add_column("Splits", style="green") | |
| table.add_column("Labels", style="yellow") | |
| if validate: | |
| table.add_column("Validation", style="red") | |
| for result in results: | |
| # Format splits | |
| splits_str = ", ".join(f"{k}:{v}" for k, v in sorted(result['splits'].items()) if v > 0) | |
| # Format labels | |
| labels_str = ", ".join(f"{k}:{v}" for k, v in sorted(result['labels'].items())) | |
| # Format objects | |
| objects_str = str(len(result['objects'])) if not detailed else ", ".join(result['objects'][:3]) | |
| if detailed and len(result['objects']) > 3: | |
| objects_str += f" (+{len(result['objects'])-3})" | |
| row = [ | |
| result['filename'], | |
| str(result['total_samples']), | |
| objects_str, | |
| splits_str, | |
| labels_str, | |
| ] | |
| if validate: | |
| val_status = "✓" if result.get('validation', {}).get('is_valid', True) else "✗" | |
| row.append(val_status) | |
| table.add_row(*row) | |
| console.print(table) | |
| if detailed: | |
| console.print("\n[bold]Detailed Statistics:[/bold]") | |
| for result in results: | |
| console.print(f"\n[cyan]{result['filename']}[/cyan]") | |
| if 'per_object' in result: | |
| for obj, obj_stats in result['per_object'].items(): | |
| console.print(f" {obj}: {obj_stats['total']} samples") | |
| console.print(f" Splits: {obj_stats['splits']}") | |
| console.print(f" Labels: {obj_stats['labels']}") | |
| @split.command('stats') | |
| @click.option('--csv', 'csv_file', type=str, required=True, help='CSV file to analyze') | |
| @click.option('--format', type=click.Choice(['table', 'json']), default='table', help='Output format') | |
| @click.option('--per-object', is_flag=True, help='Show per-object breakdown') | |
| @click.option('--per-category', is_flag=True, help='Show per-category breakdown') | |
| @click.option('--per-split', is_flag=True, help='Show per-split breakdown') | |
| def split_stats(csv_file, format, per_object, per_category, per_split): | |
| """Show detailed statistics for a single split CSV file. | |
| Examples: | |
| # Basic stats | |
| hae dataset split stats --csv splits/mvtec-bottle-pool.csv | |
| # Detailed per-object breakdown | |
| hae dataset split stats --csv splits/mvtec-split.csv --per-object | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Import utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.split_utils import ( | |
| load_split_csv, compute_statistics, compute_per_object_stats, | |
| compute_per_category_stats, compute_per_split_stats, resolve_split_csv_path | |
| ) | |
| except ImportError as e: | |
| console.print(f"[red]Error importing modules: {e}[/red]") | |
| sys.exit(1) | |
| # Load CSV | |
| try: | |
| df = load_split_csv(csv_file) | |
| except Exception as e: | |
| console.print(f"[red]Error loading CSV: {e}[/red]") | |
| sys.exit(1) | |
| # Compute statistics | |
| stats = compute_statistics(df) | |
| result = { | |
| 'filename': Path(csv_file).name, | |
| 'path': str(resolve_split_csv_path(csv_file)), | |
| 'total_samples': stats['total_samples'], | |
| 'objects': stats['objects'], | |
| 'splits': stats['splits'], | |
| 'labels': stats['labels'], | |
| 'categories': stats['categories'], | |
| 'mask_coverage': stats['mask_coverage'], | |
| } | |
| if per_object: | |
| result['per_object'] = compute_per_object_stats(df) | |
| if per_category: | |
| result['per_category'] = compute_per_category_stats(df) | |
| if per_split: | |
| result['per_split'] = compute_per_split_stats(df) | |
| # Format output | |
| if format == 'json': | |
| import json | |
| console.print(json.dumps(result, indent=2, default=str)) | |
| else: # table | |
| console.print(f"\n[bold]Statistics for: {result['filename']}[/bold]") | |
| console.print(f" Total samples: {result['total_samples']}") | |
| console.print(f" Objects: {', '.join(result['objects'])}") | |
| console.print(f" Mask coverage: {result['mask_coverage']:.1%}") | |
| console.print(f"\n[bold]Splits:[/bold]") | |
| for split, count in sorted(result['splits'].items()): | |
| console.print(f" {split}: {count}") | |
| console.print(f"\n[bold]Labels:[/bold]") | |
| for label, count in sorted(result['labels'].items()): | |
| console.print(f" {label}: {count}") | |
| if per_object: | |
| console.print(f"\n[bold]Per-Object Breakdown:[/bold]") | |
| for obj, obj_stats in result['per_object'].items(): | |
| console.print(f" {obj}: {obj_stats['total']} samples") | |
| console.print(f" Splits: {obj_stats['splits']}") | |
| console.print(f" Labels: {obj_stats['labels']}") | |
| if per_category: | |
| console.print(f"\n[bold]Per-Category Breakdown:[/bold]") | |
| for cat, cat_stats in sorted(result['per_category'].items()): | |
| console.print(f" {cat}: {cat_stats['total']} samples") | |
| console.print(f" Splits: {cat_stats['splits']}") | |
| if per_split: | |
| console.print(f"\n[bold]Per-Split Breakdown:[/bold]") | |
| for split, split_stats in sorted(result['per_split'].items()): | |
| console.print(f" {split}: {split_stats['total']} samples") | |
| console.print(f" Objects: {split_stats['objects']}") | |
| @split.command('update') | |
| @click.option('--csv', 'csv_file', type=str, required=True, help='CSV file to update') | |
| @click.option('--filter-object', type=str, help='Filter by object category') | |
| @click.option('--filter-split', type=str, help='Filter by split') | |
| @click.option('--filter-label', type=str, help='Filter by label') | |
| @click.option('--filter-category', type=str, help='Filter by category') | |
| @click.option('--filter-image', type=str, help='Filter by image path') | |
| @click.option('--set-split', type=click.Choice(['train', 'val', 'test', 'pool']), help='Set new split value') | |
| @click.option('--set-label', type=click.Choice(['normal', 'anomaly']), help='Set new label value') | |
| @click.option('--set-category', type=str, help='Set new category value') | |
| @click.option('--output-csv', type=str, help='Write to different file (default: update in-place)') | |
| @click.option('--backup', is_flag=True, help='Create backup before updating') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be updated without modifying file') | |
| @click.option('--interactive', is_flag=True, help='Interactive mode: confirm each update') | |
| def split_update(csv_file, filter_object, filter_split, filter_label, filter_category, filter_image, | |
| set_split, set_label, set_category, output_csv, backup, dry_run, interactive): | |
| """Update existing entries in split CSV file. | |
| Modifies EXISTING samples only - use 'add' command to add new samples. | |
| Requires at least one filter and one set operation. | |
| Examples: | |
| # Move all pool samples to train | |
| hae dataset split update \\ | |
| --csv splits/mvtec-bottle-pool.csv \\ | |
| --filter-split pool \\ | |
| --set-split train | |
| # Update specific object category | |
| hae dataset split update \\ | |
| --csv splits/mvtec-split.csv \\ | |
| --filter-object bottle \\ | |
| --filter-split pool \\ | |
| --set-split train | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Import utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.split_utils import ( | |
| load_split_csv, save_split_csv, resolve_split_csv_path, | |
| apply_filters, create_backup | |
| ) | |
| except ImportError as e: | |
| console.print(f"[red]Error importing modules: {e}[/red]") | |
| sys.exit(1) | |
| # Validate filters and set operations | |
| filters = { | |
| 'filter_object': filter_object, | |
| 'filter_split': filter_split, | |
| 'filter_label': filter_label, | |
| 'filter_category': filter_category, | |
| 'filter_image': filter_image, | |
| } | |
| filters = {k: v for k, v in filters.items() if v is not None} | |
| set_ops = { | |
| 'set_split': set_split, | |
| 'set_label': set_label, | |
| 'set_category': set_category, | |
| } | |
| set_ops = {k: v for k, v in set_ops.items() if v is not None} | |
| if not filters: | |
| console.print("[red]Error: At least one filter is required[/red]") | |
| sys.exit(1) | |
| if not set_ops: | |
| console.print("[red]Error: At least one set operation is required[/red]") | |
| sys.exit(1) | |
| # Load CSV | |
| try: | |
| df = load_split_csv(csv_file) | |
| console.print(f"[dim]Loaded CSV: {len(df)} samples[/dim]") | |
| except Exception as e: | |
| console.print(f"[red]Error loading CSV: {e}[/red]") | |
| sys.exit(1) | |
| # Apply filters | |
| filtered_df = apply_filters(df, **filters) | |
| if len(filtered_df) == 0: | |
| console.print("[yellow]No samples match the filter criteria[/yellow]") | |
| return | |
| console.print(f"[dim]Found {len(filtered_df)} samples matching filters[/dim]") | |
| # Apply updates | |
| updated_df = df.copy() | |
| update_mask = apply_filters(updated_df, **filters).index | |
| if 'set_split' in set_ops: | |
| updated_df.loc[update_mask, 'split'] = set_ops['set_split'] | |
| if 'set_label' in set_ops: | |
| updated_df.loc[update_mask, 'label'] = set_ops['set_label'] | |
| if 'set_category' in set_ops: | |
| updated_df.loc[update_mask, 'category'] = set_ops['set_category'] | |
| # Show what will be updated | |
| if dry_run or interactive: | |
| console.print("\n[bold]Samples to be updated:[/bold]") | |
| for idx in update_mask[:10]: # Show first 10 | |
| row = updated_df.loc[idx] | |
| console.print(f" {row['object']}/{row['image']}") | |
| if 'set_split' in set_ops: | |
| console.print(f" split: {df.loc[idx, 'split']} → {set_ops['set_split']}") | |
| if 'set_label' in set_ops: | |
| console.print(f" label: {df.loc[idx, 'label']} → {set_ops['set_label']}") | |
| if 'set_category' in set_ops: | |
| console.print(f" category: {df.loc[idx, 'category']} → {set_ops['set_category']}") | |
| if len(update_mask) > 10: | |
| console.print(f" ... and {len(update_mask) - 10} more") | |
| if dry_run: | |
| console.print("\n[bold]Dry run - no changes made[/bold]") | |
| return | |
| if interactive: | |
| if not click.confirm(f"\nUpdate {len(update_mask)} samples?"): | |
| console.print("[yellow]Cancelled[/yellow]") | |
| return | |
| # Create backup if requested | |
| if backup: | |
| backup_path = create_backup(csv_file) | |
| console.print(f"[dim]Created backup: {backup_path}[/dim]") | |
| # Determine output path | |
| output_path = Path(output_csv) if output_csv else resolve_split_csv_path(csv_file) | |
| # Save | |
| try: | |
| save_split_csv(updated_df, str(output_path)) | |
| console.print(f"\n[green]✓ Updated {len(update_mask)} samples[/green]") | |
| console.print(f" Saved to: {output_path}") | |
| except Exception as e: | |
| console.print(f"[red]Error saving CSV: {e}[/red]") | |
| sys.exit(1) | |
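| # Illustrative sketch of the filter-then-set pattern behind 'split update': | |
| # select matching row indices, then assign new column values via .loc. | |
| # Standalone pandas with a single hard-coded filter; the real command routes | |
| # arbitrary filters through dioodmi.data.split_utils.apply_filters. | |
| def _sketch_update_split(df, filter_split="pool", new_split="train"): | |
|     updated = df.copy() | |
|     idx = updated.index[updated["split"] == filter_split] | |
|     updated.loc[idx, "split"] = new_split | |
|     return updated, len(idx) | |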
| @split.command('remove') | |
| @click.option('--csv', 'csv_file', type=str, required=True, help='CSV file to update') | |
| @click.option('--filter-object', type=str, help='Filter by object category') | |
| @click.option('--filter-split', type=str, help='Filter by split') | |
| @click.option('--filter-label', type=str, help='Filter by label') | |
| @click.option('--filter-category', type=str, help='Filter by category') | |
| @click.option('--filter-image', type=str, help='Filter by image path (exact match or pattern)') | |
| @click.option('--from-csv', type=str, help='Remove samples matching entries in another CSV') | |
| @click.option('--output-csv', type=str, help='Write to different file (default: update in-place)') | |
| @click.option('--backup', is_flag=True, help='Create backup before removing') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be removed without modifying file') | |
| @click.option('--interactive', is_flag=True, help='Interactive mode: confirm removal') | |
| def split_remove(csv_file, filter_object, filter_split, filter_label, filter_category, filter_image, | |
| from_csv, output_csv, backup, dry_run, interactive): | |
| """Remove samples from split CSV file. | |
| Examples: | |
| # Remove specific image | |
| hae dataset split remove \\ | |
| --csv splits/mvtec-split.csv \\ | |
| --filter-image bottle/train/good/001.png | |
| # Remove all samples from specific split | |
| hae dataset split remove \\ | |
| --csv splits/mvtec-split.csv \\ | |
| --filter-split pool | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Import utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.split_utils import ( | |
| load_split_csv, save_split_csv, resolve_split_csv_path, | |
| apply_filters, create_backup | |
| ) | |
| except ImportError as e: | |
| console.print(f"[red]Error importing modules: {e}[/red]") | |
| sys.exit(1) | |
| # Load CSV | |
| try: | |
| df = load_split_csv(csv_file) | |
| console.print(f"[dim]Loaded CSV: {len(df)} samples[/dim]") | |
| except Exception as e: | |
| console.print(f"[red]Error loading CSV: {e}[/red]") | |
| sys.exit(1) | |
| # Determine what to remove | |
| if from_csv: | |
| # Remove samples matching entries in another CSV | |
| try: | |
| remove_df = load_split_csv(from_csv) | |
| remove_keys = set(zip(remove_df['object'], remove_df['image'])) | |
| remove_mask = df.apply(lambda row: (row['object'], row['image']) in remove_keys, axis=1) | |
| except Exception as e: | |
| console.print(f"[red]Error loading source CSV: {e}[/red]") | |
| sys.exit(1) | |
| else: | |
| # Apply filters | |
| filters = { | |
| 'filter_object': filter_object, | |
| 'filter_split': filter_split, | |
| 'filter_label': filter_label, | |
| 'filter_category': filter_category, | |
| 'filter_image': filter_image, | |
| } | |
| filters = {k: v for k, v in filters.items() if v is not None} | |
| if not filters: | |
| console.print("[red]Error: At least one filter or --from-csv is required[/red]") | |
| sys.exit(1) | |
| filtered_df = apply_filters(df, **filters) | |
| remove_mask = df.index.isin(filtered_df.index) | |
| if not remove_mask.any(): | |
| console.print("[yellow]No samples match the removal criteria[/yellow]") | |
| return | |
| num_to_remove = remove_mask.sum() | |
| console.print(f"[dim]Found {num_to_remove} samples to remove[/dim]") | |
| if num_to_remove == len(df): | |
| console.print("[yellow]Warning: This will remove ALL samples from the CSV[/yellow]") | |
| if not interactive and not dry_run: | |
| if not click.confirm("Continue?"): | |
| console.print("[yellow]Cancelled[/yellow]") | |
| return | |
| # Show what will be removed | |
| if dry_run or interactive: | |
| console.print("\n[bold]Samples to be removed:[/bold]") | |
| to_remove = df[remove_mask] | |
| for idx, row in to_remove.head(10).iterrows(): | |
| console.print(f" {row['object']}/{row['image']} ({row['split']}, {row['label']})") | |
| if num_to_remove > 10: | |
| console.print(f" ... and {num_to_remove - 10} more") | |
| if dry_run: | |
| console.print("\n[bold]Dry run - no changes made[/bold]") | |
| return | |
| if interactive: | |
| if not click.confirm(f"\nRemove {num_to_remove} samples?"): | |
| console.print("[yellow]Cancelled[/yellow]") | |
| return | |
| # Create backup if requested | |
| if backup: | |
| backup_path = create_backup(csv_file) | |
| console.print(f"[dim]Created backup: {backup_path}[/dim]") | |
| # Remove samples | |
| updated_df = df[~remove_mask].copy() | |
| # Determine output path | |
| output_path = Path(output_csv) if output_csv else resolve_split_csv_path(csv_file) | |
| # Save | |
| try: | |
| save_split_csv(updated_df, str(output_path)) | |
| console.print(f"\n[green]✓ Removed {num_to_remove} samples[/green]") | |
| console.print(f" Remaining samples: {len(updated_df)}") | |
| console.print(f" Saved to: {output_path}") | |
| except Exception as e: | |
| console.print(f"[red]Error saving CSV: {e}[/red]") | |
| sys.exit(1) | |
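| # Illustrative sketch of the --from-csv removal path: build a boolean mask | |
| # from the (object, image) keys of the other CSV and keep the complement. | |
| # Standalone pandas version of the logic above. | |
| def _sketch_remove_by_csv(df, remove_df): | |
|     if df.empty or remove_df.empty: | |
|         return df.copy(), 0 | |
|     remove_keys = set(zip(remove_df["object"], remove_df["image"])) | |
|     mask = df.apply(lambda row: (row["object"], row["image"]) in remove_keys, axis=1) | |
|     return df[~mask].copy(), int(mask.sum()) | |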
| @split.command('validate') | |
| @click.option('--csv', 'csv_file', type=str, required=True, help='CSV file to validate') | |
| @click.option('--check-duplicates', is_flag=True, help='Check for duplicate entries') | |
| @click.option('--check-paths', is_flag=True, help='Check path formats') | |
| @click.option('--check-structure', is_flag=True, help='Check CSV structure') | |
| @click.option('--fix', is_flag=True, help='Auto-fix issues where possible') | |
| @click.option('--output-csv', type=str, help='Write fixed CSV to different file') | |
| @click.option('--output-report', type=str, help='Write validation report to file') | |
| @click.option('--format', type=click.Choice(['table', 'json']), default='table', help='Output format') | |
| def split_validate(csv_file, check_duplicates, check_paths, check_structure, fix, | |
| output_csv, output_report, format): | |
| """Validate split CSV file structure. | |
| Examples: | |
| # Full validation | |
| hae dataset split validate \\ | |
| --csv splits/mvtec-split.csv \\ | |
| --check-duplicates --check-paths --check-structure | |
| # Auto-fix issues | |
| hae dataset split validate \\ | |
| --csv splits/mvtec-split.csv \\ | |
| --fix \\ | |
| --output-csv splits/mvtec-split-fixed.csv | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Import utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.split_utils import ( | |
| load_split_csv, save_split_csv, validate_structure, | |
| find_duplicates, remove_duplicates | |
| ) | |
| except ImportError as e: | |
| console.print(f"[red]Error importing modules: {e}[/red]") | |
| sys.exit(1) | |
| # Load CSV (this already validates schema) | |
| try: | |
| df = load_split_csv(csv_file) | |
| except Exception as e: | |
| console.print(f"[red]Error loading CSV: {e}[/red]") | |
| sys.exit(1) | |
| # Run validations | |
| results = { | |
| 'is_valid': True, | |
| 'errors': [], | |
| 'warnings': [], | |
| 'fixes_applied': [], | |
| } | |
| # Structure validation (runs when requested, or by default when no specific checks are selected) | |
| if check_structure or not (check_duplicates or check_paths): | |
| structure = validate_structure(df) | |
| results['is_valid'] = structure['is_valid'] | |
| results['errors'].extend(structure.get('errors', [])) | |
| results['warnings'].extend(structure.get('warnings', [])) | |
| # Duplicate check | |
| if check_duplicates: | |
| duplicates = find_duplicates(df) | |
| if len(duplicates) > 0: | |
| results['is_valid'] = False | |
| results['warnings'].append(f"Found {len(duplicates)} duplicate entries") | |
| if fix: | |
| df = remove_duplicates(df, strategy='keep-first') | |
| results['fixes_applied'].append(f"Removed {len(duplicates)} duplicate entries") | |
| # Path check | |
| if check_paths: | |
| invalid_paths = df['image'].isnull().sum() + (df['image'] == '').sum() | |
| if invalid_paths > 0: | |
| results['is_valid'] = False | |
| results['errors'].append(f"Found {invalid_paths} invalid image paths") | |
| if fix: | |
| # Remove rows with invalid paths | |
| df = df[df['image'].notna() & (df['image'] != '')].copy() | |
| results['fixes_applied'].append(f"Removed {invalid_paths} rows with invalid paths") | |
| # Format output | |
| if format == 'json': | |
| import json | |
| console.print(json.dumps(results, indent=2, default=str)) | |
| else: # table | |
| if results['is_valid']: | |
| console.print("[green]✓ CSV file is valid[/green]") | |
| else: | |
| console.print("[red]✗ CSV file has issues[/red]") | |
| if results['errors']: | |
| console.print("\n[bold red]Errors:[/bold red]") | |
| for error in results['errors']: | |
| console.print(f" - {error}") | |
| if results['warnings']: | |
| console.print("\n[bold yellow]Warnings:[/bold yellow]") | |
| for warning in results['warnings']: | |
| console.print(f" - {warning}") | |
| if results['fixes_applied']: | |
| console.print("\n[bold green]Fixes Applied:[/bold green]") | |
| for fix_applied in results['fixes_applied']: | |
| console.print(f" - {fix_applied}") | |
| # Save fixed CSV if requested (without --output-csv the fixes only exist in memory) | |
| if fix and results['fixes_applied']: | |
| if output_csv: | |
| try: | |
| save_split_csv(df, output_csv) | |
| console.print(f"\n[green]✓ Saved fixed CSV to: {output_csv}[/green]") | |
| except Exception as e: | |
| console.print(f"[red]Error saving fixed CSV: {e}[/red]") | |
| else: | |
| console.print("[yellow]Fixes not written; pass --output-csv to save them[/yellow]") | |
| # Save report if requested | |
| if output_report: | |
| import json | |
| with open(output_report, 'w') as f: | |
| json.dump(results, f, indent=2, default=str) | |
| console.print(f"\n[dim]Validation report saved to: {output_report}[/dim]") | |
| # Exit with error code if invalid | |
| if not results['is_valid']: | |
| sys.exit(1) | |
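| # Illustrative sketch of the duplicate fix in 'split validate'. The real | |
| # remove_duplicates helper is assumed to behave like drop_duplicates on the | |
| # (object, image) key with keep='first' (an assumption, not the actual API). | |
| def _sketch_fix_duplicates(df): | |
|     before = len(df) | |
|     fixed = df.drop_duplicates(subset=["object", "image"], keep="first").copy() | |
|     return fixed, before - len(fixed) | |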
| @split.command('apply') | |
| @click.option('--csv', 'csv_file', type=str, required=True, help='Pool CSV file to split') | |
| @click.option('--output', 'output_csv', type=str, required=True, help='Output CSV file path') | |
| @click.option('--strategy', type=click.Choice(['random', 'ratio', 'leave-k-out', 'k-fold']), required=True, | |
| help='Splitting strategy') | |
| @click.option('--train-ratio', type=float, help='Train ratio (for random/ratio strategy)') | |
| @click.option('--val-ratio', type=float, help='Validation ratio (for random/ratio strategy)') | |
| @click.option('--test-ratio', type=float, help='Test ratio (for random/ratio strategy)') | |
| @click.option('--k', type=int, help='Number of samples to leave out (for leave-k-out)') | |
| @click.option('--n-folds', type=int, help='Number of folds (for k-fold)') | |
| @click.option('--fold-idx', type=int, default=0, help='Fold index to use as test (for k-fold)') | |
| @click.option('--from-split', type=str, default='pool', help='Source split (for leave-k-out)') | |
| @click.option('--to-split', type=str, default='val', help='Target split (for leave-k-out)') | |
| @click.option('--seed', type=int, default=42, help='Random seed') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be created without modifying files') | |
| def split_apply(csv_file, output_csv, strategy, train_ratio, val_ratio, test_ratio, | |
| k, n_folds, fold_idx, from_split, to_split, seed, dry_run): | |
| """Apply splitting strategy to convert pool CSV into train/val/test splits. | |
| Examples: | |
| # Random split with ratios | |
| hae dataset split apply \\ | |
| --csv splits/mvtec-bottle-pool.csv \\ | |
| --output splits/mvtec-bottle-split.csv \\ | |
| --strategy random \\ | |
| --train-ratio 0.7 --val-ratio 0.15 --test-ratio 0.15 \\ | |
| --seed 42 | |
| # Leave-k-out | |
| hae dataset split apply \\ | |
| --csv splits/mvtec-bottle-pool.csv \\ | |
| --output splits/mvtec-bottle-split.csv \\ | |
| --strategy leave-k-out \\ | |
| --k 20 \\ | |
| --from-split pool \\ | |
| --to-split val | |
| """ | |
| from pathlib import Path | |
| import sys | |
| # Import utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.split_utils import load_split_csv, save_split_csv | |
| from dioodmi.data.split_strategies import ( | |
| apply_random_split, apply_leave_k_out_split, apply_k_fold_split | |
| ) | |
| except ImportError as e: | |
| console.print(f"[red]Error importing modules: {e}[/red]") | |
| sys.exit(1) | |
| # Load CSV | |
| try: | |
| df = load_split_csv(csv_file) | |
| console.print(f"[dim]Loaded pool CSV: {len(df)} samples[/dim]") | |
| except Exception as e: | |
| console.print(f"[red]Error loading CSV: {e}[/red]") | |
| sys.exit(1) | |
| # Apply strategy | |
| if strategy == 'random' or strategy == 'ratio': | |
| if train_ratio is None or val_ratio is None or test_ratio is None: | |
| console.print("[red]Error: --train-ratio, --val-ratio, and --test-ratio required for random/ratio strategy[/red]") | |
| sys.exit(1) | |
| if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6: | |
| console.print("[red]Error: Ratios must sum to 1.0[/red]") | |
| sys.exit(1) | |
| ratios = {'train': train_ratio, 'val': val_ratio, 'test': test_ratio} | |
| result_df = apply_random_split(df, ratios, seed=seed, respect_existing=False) | |
| elif strategy == 'leave-k-out': | |
| if k is None: | |
| console.print("[red]Error: --k required for leave-k-out strategy[/red]") | |
| sys.exit(1) | |
| result_df = apply_leave_k_out_split( | |
| df, k=k, from_split=from_split, to_split=to_split, | |
| seed=seed, respect_existing=False | |
| ) | |
| elif strategy == 'k-fold': | |
| if n_folds is None: | |
| console.print("[red]Error: --n-folds required for k-fold strategy[/red]") | |
| sys.exit(1) | |
| result_df = apply_k_fold_split( | |
| df, n_folds=n_folds, fold_idx=fold_idx, seed=seed, | |
| respect_existing=False | |
| ) | |
| # Show results | |
| splits = result_df['split'].value_counts().to_dict() | |
| console.print("\n[bold]Split distribution:[/bold]") | |
| for split, count in sorted(splits.items()): | |
| console.print(f" {split}: {count}") | |
| if dry_run: | |
| console.print("\n[bold]Dry run - no changes made[/bold]") | |
| return | |
| # Save | |
| try: | |
| save_split_csv(result_df, output_csv) | |
| console.print(f"\n[green]✓ Created split CSV: {output_csv}[/green]") | |
| console.print(f" Total samples: {len(result_df)}") | |
| except Exception as e: | |
| console.print(f"[red]Error saving CSV: {e}[/red]") | |
| sys.exit(1) | |
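| # Illustrative sketch of the random ratio strategy: shuffle row positions | |
| # with a seeded RNG, then cut at the cumulative ratio boundaries. This is a | |
| # standalone guess at what apply_random_split does; the real implementation | |
| # lives in dioodmi.data.split_strategies. | |
| def _sketch_random_split(df, train_ratio=0.7, val_ratio=0.15, seed=42): | |
|     import numpy as np | |
|     rng = np.random.default_rng(seed) | |
|     order = rng.permutation(len(df)) | |
|     n_train = int(len(df) * train_ratio) | |
|     n_val = int(len(df) * val_ratio) | |
|     result = df.copy() | |
|     split_col = result.columns.get_loc("split") | |
|     result.iloc[order[:n_train], split_col] = "train" | |
|     result.iloc[order[n_train:n_train + n_val], split_col] = "val" | |
|     result.iloc[order[n_train + n_val:], split_col] = "test" | |
|     return result | |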
| @dataset.command('show') | |
| @click.argument('dataset_name') | |
| @click.option('--object-category', type=str, required=True, help='Object category to show details for') | |
| @click.option('--data-dir', type=str, help='Data directory path (default: ./datasets/raw/MVTec/mvtec_ad)') | |
| @click.option('--csv-split-file', type=str, help='CSV split file path') | |
| def dataset_show(dataset_name, object_category, data_dir, csv_split_file): | |
| """Show detailed information about a specific dataset and category. | |
| Displays comprehensive statistics including: | |
| - Dataset configuration | |
| - Split breakdowns (train/val/test) | |
| - CSV split file details | |
| - Sample paths and counts | |
| - Configuration parameters | |
| Examples: | |
| hae dataset show mvtec --object-category bottle | |
| hae dataset show mvtec --object-category bottle --data-dir ./my-data | |
| hae dataset show mvtec --object-category bottle --csv-split-file splits/custom.csv | |
| """ | |
| from pathlib import Path | |
| import sys | |
| import os | |
| # Import dataset utilities | |
| sys.path.insert(0, str(Path(__file__).parent.parent.parent / "dioodmi-py" / "src")) | |
| try: | |
| from dioodmi.data.dataset_configs import BASE_DATASET_REGISTRY, get_base_dataset_config | |
| from dioodmi.data.anomaly_core import DatasetConfig, ProcessingParams, parse_csv_data | |
| except ImportError as e: | |
| console.print(f"[red]Error importing dataset modules: {e}[/red]") | |
| console.print("[yellow]Make sure dioodmi-py is properly installed[/yellow]") | |
| sys.exit(1) | |
| # Validate dataset | |
| if dataset_name not in BASE_DATASET_REGISTRY: | |
| console.print(f"[red]Unknown dataset: {dataset_name}[/red]") | |
| console.print(f"[yellow]Available datasets: {', '.join(BASE_DATASET_REGISTRY.keys())}[/yellow]") | |
| sys.exit(1) | |
| base_config = BASE_DATASET_REGISTRY[dataset_name] | |
| # Validate object category | |
| if object_category not in base_config.object_classes: | |
| console.print(f"[red]Unknown category '{object_category}' for dataset '{dataset_name}'[/red]") | |
| console.print(f"[yellow]Available categories: {', '.join(base_config.object_classes.keys())}[/yellow]") | |
| sys.exit(1) | |
| # Default data directory | |
| if not data_dir: | |
| data_dir = './datasets/raw/MVTec/mvtec_ad' | |
| # Use provided CSV split file or default | |
| csv_file = csv_split_file if csv_split_file else base_config.csv_split_file | |
| # Display header | |
| console.print(f"\n[bold cyan]Dataset: {dataset_name} / Category: {object_category}[/bold cyan]") | |
| # Show configuration | |
| console.print(f"\n[bold]Configuration:[/bold]") | |
| console.print(f" Dataset Name: {base_config.name}") | |
| console.print(f" Object Category: {object_category}") | |
| console.print(f" Category Index: {base_config.object_classes[object_category]}") | |
| console.print(f" Data Directory: {data_dir}") | |
| console.print(f" CSV Split File: {csv_file}") | |
| # Get original image dimensions from actual files | |
| original_dims = None | |
| try: | |
| # Try to load a sample image to get dimensions | |
| config = DatasetConfig( | |
| name=base_config.name, | |
| object_classes=base_config.object_classes, | |
| csv_columns=base_config.csv_columns, | |
| default_image_size=base_config.default_image_size, | |
| default_crop_size=base_config.default_crop_size, | |
| csv_split_file=csv_file, | |
| strategies=base_config.strategies, | |
| temporal_strategies=base_config.temporal_strategies, | |
| mode='train', | |
| object_class=object_category, | |
| rootdir=data_dir, | |
| anomaly_class='good' | |
| ) | |
| params = ProcessingParams( | |
| image_size=base_config.default_image_size, | |
| crop_size=base_config.default_crop_size | |
| ) | |
| data_df = parse_csv_data(config, params) | |
| if len(data_df) > 0: | |
| # Get image column name | |
| image_col = None | |
| for col in ['image', 'image_path', 'filepath', 'path']: | |
| if col in data_df.columns: | |
| image_col = col | |
| break | |
| if image_col: | |
| # Try to load first image to get dimensions | |
| first_image_path = data_df[image_col].iloc[0] | |
| # Resolve path relative to data directory | |
| if not os.path.isabs(first_image_path): | |
| # Path in CSV is relative, try multiple locations | |
| possible_paths = [ | |
| os.path.join(data_dir, first_image_path), | |
| os.path.join(data_dir, base_config.name, first_image_path), | |
| first_image_path # Try as-is | |
| ] | |
| else: | |
| possible_paths = [first_image_path] | |
| for full_path in possible_paths: | |
| full_path_abs = os.path.abspath(full_path) | |
| if os.path.exists(full_path_abs): | |
| try: | |
| from PIL import Image | |
| with Image.open(full_path_abs) as img: | |
| original_dims = img.size # (width, height) | |
| break | |
| except Exception as e: | |
| # Silently continue to next path | |
| continue | |
| except Exception as e: | |
| # Silently fail - original dimensions are optional | |
| pass | |
| # Show image dimensions | |
| console.print(f"\n[bold]Image Dimensions:[/bold]") | |
| if original_dims: | |
| console.print(f" Original Size: {original_dims[0]}×{original_dims[1]} (width×height)") | |
| else: | |
| console.print(f" Original Size: [dim]Unable to determine (check data directory)[/dim]") | |
| console.print(f" Training Config:") | |
| console.print(f" - Image Size: {base_config.default_image_size} (resize target)") | |
| console.print(f" - Crop Size: {base_config.default_crop_size} (crop target)") | |
| # Show split statistics | |
| console.print(f"\n[bold]Split Statistics:[/bold]") | |
| split_stats = {} | |
| for split in ['train', 'val', 'test']: | |
| try: | |
| # Create dataset config | |
| config = DatasetConfig( | |
| name=base_config.name, | |
| object_classes=base_config.object_classes, | |
| csv_columns=base_config.csv_columns, | |
| default_image_size=base_config.default_image_size, | |
| default_crop_size=base_config.default_crop_size, | |
| csv_split_file=csv_file, | |
| strategies=base_config.strategies, | |
| temporal_strategies=base_config.temporal_strategies, | |
| mode=split, | |
| object_class=object_category, | |
| rootdir=data_dir, | |
| anomaly_class='good' if split == 'train' else 'all' | |
| ) | |
| # Create processing params | |
| params = ProcessingParams( | |
| image_size=base_config.default_image_size, | |
| crop_size=base_config.default_crop_size | |
| ) | |
| # Load dataset metadata | |
| data_df = parse_csv_data(config, params) | |
| count = len(data_df) | |
| split_stats[split] = { | |
| 'count': count, | |
| 'dataframe': data_df | |
| } | |
| # Show anomaly class breakdown for test split | |
| if split == 'test' and 'category' in data_df.columns: | |
| anomaly_counts = data_df['category'].value_counts().to_dict() | |
| console.print(f" {split.capitalize()}: {count} images") | |
| for anomaly_class, anomaly_count in sorted(anomaly_counts.items()): | |
| console.print(f" - {anomaly_class}: {anomaly_count}") | |
| else: | |
| console.print(f" {split.capitalize()}: {count} images") | |
| except Exception as e: | |
| console.print(f" {split.capitalize()}: [red]Error - {str(e)}[/red]") | |
| split_stats[split] = {'count': None, 'error': str(e)} | |
| # Show sample paths (first few from train split) | |
| if 'train' in split_stats and split_stats['train']['count'] and split_stats['train']['count'] > 0: | |
| train_df = split_stats['train']['dataframe'] | |
| # Check for common column names for image paths | |
| image_col = None | |
| for col in ['image', 'image_path', 'filepath', 'path']: | |
| if col in train_df.columns: | |
| image_col = col | |
| break | |
| if image_col: | |
| console.print(f"\n[bold]Sample Image Paths (train, first 5):[/bold]") | |
| for i, path in enumerate(train_df[image_col].head(5), 1): | |
| console.print(f" {i}. {path}") | |
| if len(train_df) > 5: | |
| console.print(f" ... and {len(train_df) - 5} more") | |
| # Show total statistics | |
| total_images = sum(s['count'] for s in split_stats.values() if s.get('count') is not None) | |
| console.print(f"\n[bold]Total Images:[/bold] {total_images}") | |
| # Show CSV file location | |
| csv_path = Path(csv_file) | |
| if not csv_path.is_absolute(): | |
| # Try relative to splits directory; otherwise keep the path as given | |
| splits_path = Path("splits") / csv_file | |
| if splits_path.exists(): | |
| csv_path = splits_path | |
| if csv_path.exists(): | |
| console.print(f"\n[bold]CSV Split File:[/bold] {csv_path.absolute()}") | |
| console.print(f"[dim]File exists: Yes[/dim]") | |
| else: | |
| console.print(f"\n[bold]CSV Split File:[/bold] {csv_file}") | |
| console.print(f"[yellow]File exists: No (may use default dataset structure)[/yellow]") | |
| @experiment.command('clean') | |
| @click.argument('experiment_id') | |
| @click.option('--all', is_flag=True, help='Clean all results (WARNING: destructive)') | |
| def experiment_clean(experiment_id, all): | |
| """Clean temporary files or all results. | |
| By default, cleans temporary files (__pycache__, *.tmp, .ipynb_checkpoints). | |
| Use --all to delete all execution results (WARNING: destructive). | |
| Examples: | |
| hae experiment clean train-and-compare-evaluators | |
| hae experiment clean train-and-compare-evaluators --all | |
| """ | |
| import shutil | |
| # For orchestration format experiments, clean execution directories | |
| executions_dir = Path("experiment-history") | |
| if executions_dir.exists(): | |
| # Find all executions matching this experiment ID | |
| matching_execs = [] | |
| exp_id_norm = experiment_id.replace("-", "_").replace(" ", "_") | |
| for exec_dir in executions_dir.rglob("*"): | |
| if exec_dir.is_dir() and "_" in exec_dir.name: | |
| parts = exec_dir.name.split("_", 1) | |
| if len(parts) == 2: | |
| exec_name = parts[1] | |
| exec_name_norm = exec_name.replace("-", "_").replace(" ", "_") | |
| if exec_name_norm == exp_id_norm or exec_name == experiment_id: | |
| matching_execs.append(exec_dir) | |
| if matching_execs: | |
| if all: | |
| if click.confirm(f"⚠️ Delete ALL {len(matching_execs)} execution(s) for {experiment_id}?", abort=True): | |
| for exec_dir in matching_execs: | |
| shutil.rmtree(exec_dir) | |
| console.print(f"[green]✅ Deleted {len(matching_execs)} execution(s)[/green]") | |
| return | |
| else: | |
| # Clean temporary files only | |
| cleaned_count = 0 | |
| for exec_dir in matching_execs: | |
| for pattern in ["**/__pycache__", "**/*.tmp", "**/.ipynb_checkpoints"]: | |
| for path in exec_dir.glob(pattern): | |
| if path.is_dir(): | |
| shutil.rmtree(path) | |
| cleaned_count += 1 | |
| else: | |
| path.unlink() | |
| cleaned_count += 1 | |
| if cleaned_count > 0: | |
| console.print(f"[green]✅ Cleaned {cleaned_count} temporary file(s) from {len(matching_execs)} execution(s)[/green]") | |
| else: | |
| console.print(f"[dim]No temporary files found in {len(matching_execs)} execution(s)[/dim]") | |
| return | |
| # Fallback: try alternative experiment directory structure | |
| exp_dir = Path("experiments") / experiment_id | |
| if exp_dir.exists(): | |
| if all: | |
| if click.confirm(f"⚠️ Delete ALL results for {experiment_id}?", abort=True): | |
| results_dir = exp_dir / "results" | |
| if results_dir.exists(): | |
| shutil.rmtree(results_dir) | |
| console.print("[green]✅ All results deleted[/green]") | |
| else: | |
| # Clean temporary files only | |
| cleaned_count = 0 | |
| for pattern in ["**/__pycache__", "**/*.tmp", "**/.ipynb_checkpoints"]: | |
| for path in exp_dir.glob(pattern): | |
| if path.is_dir(): | |
| shutil.rmtree(path) | |
| cleaned_count += 1 | |
| else: | |
| path.unlink() | |
| cleaned_count += 1 | |
| if cleaned_count > 0: | |
| console.print(f"[green]✅ Cleaned {cleaned_count} temporary file(s)[/green]") | |
| else: | |
| console.print("[dim]No temporary files found[/dim]") | |
| else: | |
| console.print(f"[yellow]No experiment found: {experiment_id}[/yellow]") | |
| console.print(f"[dim]Tried:[/dim]") | |
| console.print(f" - Execution directories in experiment-history/") | |
| console.print(f" - Alternative experiment directory: {exp_dir}") | |
| @experiment.command('sync') | |
| @click.option('--full', is_flag=True, help='Full rebuild (clears and re-indexes everything)') | |
| @click.option('--verify', is_flag=True, help='Verify database integrity') | |
| def experiment_sync(full, verify): | |
| """Sync experiment database with YAML files and execution history. | |
| By default, performs smart incremental sync (only changed files). | |
| Safe to run frequently (e.g., after git pull). | |
| The database can always be rebuilt from YAML files (ground truth). | |
| Feel free to delete .hae/ directory - sync will recreate it. | |
| Examples: | |
| hae experiment sync # Smart sync (fast, incremental) | |
| hae experiment sync --full # Full rebuild (slower) | |
| hae experiment sync --verify # Check database integrity | |
| """ | |
| from hae.db import ExperimentDB | |
| from hae.migration import sync_database_incremental | |
| db = ExperimentDB() | |
| if verify: | |
| console.print("[bold]Verifying database integrity...[/bold]") | |
| db.verify() | |
| db.close() | |
| return | |
| if full: | |
| # Full rebuild - clear and re-index everything | |
| console.print("[bold]Full database rebuild...[/bold]") | |
| db.cursor.execute("DELETE FROM experiment_definitions") | |
| db.cursor.execute("DELETE FROM experiment_executions") | |
| db.cursor.execute("DELETE FROM links") | |
| db.conn.commit() | |
| console.print(" [dim]Indexing definitions...[/dim]") | |
| db.index_definitions() | |
| console.print(" [dim]Indexing execution history...[/dim]") | |
| db.index_executions() | |
| def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0] | |
| exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0] | |
| console.print(f"[green]✓ Full rebuild complete ({def_count} definitions, {exec_count} executions)[/green]") | |
| else: | |
| # Smart incremental sync | |
| console.print("[bold]Syncing experiment database...[/bold]") | |
| # Check if database is empty (first run or deleted) | |
| def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0] | |
| if def_count == 0: | |
| # Empty database - do full index | |
| console.print(" [dim]Database empty - performing full index...[/dim]") | |
| db.index_definitions() | |
| db.index_executions() | |
| def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0] | |
| exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0] | |
| console.print(f"[green]✓ Indexed {def_count} definitions, {exec_count} executions[/green]") | |
| else: | |
| # Incremental sync - only changed files | |
| new, updated, unchanged = sync_database_incremental(db) | |
| # Also sync executions (always incremental) | |
| old_exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0] | |
| db.index_executions() | |
| new_exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0] | |
| new_execs = new_exec_count - old_exec_count | |
| console.print(f"[green]✓ Definitions:[/green] {new} new, {updated} updated, {unchanged} unchanged") | |
| if new_execs > 0: | |
| console.print(f"[green]✓ Executions:[/green] {new_execs} new") | |
| else: | |
| console.print(f"[dim] Executions: No new records[/dim]") | |
| console.print(f"\n[dim]Database: {db.db_path}[/dim]") | |
| db.close() | |
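| # Illustrative sketch of incremental sync: re-index only files whose content | |
| # hash changed since the last run. The real sync_database_incremental in | |
| # hae.migration is assumed to follow this change-detection shape (a guess, | |
| # not its actual implementation). | |
| def _sketch_incremental_sync(yaml_dir, known_hashes): | |
|     import hashlib | |
|     from pathlib import Path | |
|     new, updated, unchanged = [], [], [] | |
|     for path in sorted(Path(yaml_dir).glob("*.yaml")): | |
|         digest = hashlib.sha256(path.read_bytes()).hexdigest() | |
|         if path.name not in known_hashes: | |
|             new.append(path.name) | |
|         elif known_hashes[path.name] != digest: | |
|             updated.append(path.name) | |
|         else: | |
|             unchanged.append(path.name) | |
|         known_hashes[path.name] = digest | |
|     return new, updated, unchanged | |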
| @experiment.command('doctor') | |
| @click.option('--migration-only', is_flag=True, help='Check only migration status (quick)') | |
| @click.option('--fix', is_flag=True, help='Auto-fix issues where possible') | |
| @click.option('--verbose', is_flag=True, help='Show detailed diagnostics') | |
| def experiment_doctor(migration_only, fix, verbose): | |
| """Comprehensive health check for experiment system. | |
| Checks migration status, validates experiments, and verifies database. | |
| This is the recommended command to run after pulling new changes. | |
| Examples: | |
| hae experiment doctor # Full health check | |
| hae experiment doctor --migration-only # Quick migration status | |
| hae experiment doctor --fix # Auto-repair issues | |
| hae experiment doctor --verbose # Detailed diagnostics | |
| """ | |
| from hae.migration import ( | |
| check_migration_status, | |
| find_old_path_references, | |
| validate_yaml_syntax, | |
| validate_dependencies, | |
| validate_file_references, | |
| fix_yaml_paths | |
| ) | |
| from hae.db import ExperimentDB | |
| from rich.table import Table | |
| console.print("[bold cyan]🏥 DiOodMi Experiment Health Check[/bold cyan]\n") | |
| # Track overall status | |
| errors = 0 | |
| warnings = 0 | |
| # ======================================================================== | |
| # 1. Migration Status Check | |
| # ======================================================================== | |
| console.print("[bold]Migration Status:[/bold]") | |
| status = check_migration_status() | |
| if status.old_definitions_exists: | |
| console.print(f" [red]❌ OLD:[/red] experiments/definitions/ ({status.old_definitions_count} files)") | |
| errors += 1 | |
| if status.old_history_exists: | |
| console.print(f" [red]❌ OLD:[/red] experiments/history/ ({status.old_history_count} records)") | |
| errors += 1 | |
| if status.new_definitions_exists: | |
| console.print(f" [green]✅ NEW:[/green] definitions/ ({status.new_definitions_count} files)") | |
| else: | |
| console.print(f" [yellow]⚠️ NEW:[/yellow] definitions/ not found") | |
| warnings += 1 | |
| if status.new_history_exists: | |
| console.print(f" [dim] NEW: experiment-history/ ({status.new_history_count} records)[/dim]") | |
| if status.needs_migration: | |
| console.print(f" [cyan]💡 Action:[/cyan] Run 'hae experiment migrate'\n") | |
| else: | |
| console.print(f" [green]✓ Migration complete[/green]\n") | |
| # If migration-only flag, stop here | |
| if migration_only: | |
| summary = f"Migration {'needed' if status.needs_migration else 'complete'}" | |
| console.print(f"[dim]{summary}[/dim]") | |
| sys.exit(1 if status.needs_migration else 0) | |
| # ======================================================================== | |
| # 2. Experiment Definitions Validation | |
| # ======================================================================== | |
| console.print("[bold]Experiment Definitions:[/bold]") | |
| yaml_dir = Path("definitions/experiments") | |
| if not yaml_dir.exists(): | |
| console.print(" [yellow]⚠️ definitions/experiments/ not found[/yellow]\n") | |
| warnings += 1 | |
| else: | |
| # YAML syntax validation | |
| yaml_files = list(yaml_dir.glob("*.yaml")) | |
| valid_count = 0 | |
| invalid_files = [] | |
| for yaml_file in yaml_files: | |
| result = validate_yaml_syntax(yaml_file) | |
| if result.valid: | |
| valid_count += 1 | |
| else: | |
| invalid_files.append((yaml_file.name, result.error)) | |
| if invalid_files: | |
| console.print(f" [red]❌ YAML Syntax ({valid_count}/{len(yaml_files)} valid):[/red]") | |
| for filename, error in invalid_files[:5]: # Show first 5 | |
| error_short = error[:60] + "..." if len(error) > 60 else error | |
| console.print(f" • {filename}: {error_short}") | |
| if len(invalid_files) > 5: | |
| console.print(f" ... and {len(invalid_files) - 5} more") | |
| errors += len(invalid_files) | |
| else: | |
| console.print(f" [green]✅ YAML Syntax ({len(yaml_files)}/{len(yaml_files)} valid)[/green]") | |
| # Dependency validation | |
| dep_issues = validate_dependencies(yaml_dir) | |
| if dep_issues: | |
| console.print(f" [red]❌ Dependencies ({len(dep_issues)} issues):[/red]") | |
| for issue in dep_issues[:5]: # Show first 5 | |
| console.print(f" • {issue.experiment_id} → {issue.missing_dependency} ({issue.dependency_type} not found)") | |
| if len(dep_issues) > 5: | |
| console.print(f" ... and {len(dep_issues) - 5} more") | |
| errors += len(dep_issues) | |
| else: | |
| console.print(f" [green]✅ Dependencies (no broken references)[/green]") | |
| # File reference validation | |
| file_issues = validate_file_references(yaml_dir) | |
| if file_issues: | |
| console.print(f" [yellow]⚠️ Missing Files ({len(file_issues)} warnings):[/yellow]") | |
| for issue in file_issues[:5]: # Show first 5 | |
| console.print(f" • {issue.missing_file} (ref in {issue.experiment_id})") | |
| if len(file_issues) > 5: | |
| console.print(f" ... and {len(file_issues) - 5} more") | |
| warnings += len(file_issues) | |
| else: | |
| console.print(f" [green]✅ File References (all files found)[/green]") | |
| # Old path references | |
| path_issues = find_old_path_references(yaml_dir) | |
| if path_issues: | |
| console.print(f" [yellow]⚠️ Old Path References ({len(path_issues)} found):[/yellow]") | |
| for issue in path_issues[:3]: # Show first 3 | |
| console.print(f" • {issue.file.name}:{issue.line} ({issue.old_path})") | |
| if len(path_issues) > 3: | |
| console.print(f" ... and {len(path_issues) - 3} more") | |
| warnings += len(path_issues) | |
| if fix: | |
| console.print(f" [cyan]🔧 Fixing path references...[/cyan]") | |
| fixed_count, actions = fix_yaml_paths(yaml_dir, dry_run=False) | |
| console.print(f" [green]✓ Fixed {fixed_count} files[/green]") | |
| warnings -= len(path_issues) # Remove warnings since we fixed them | |
| console.print() | |
| # ======================================================================== | |
| # 3. Database Status | |
| # ======================================================================== | |
| console.print("[bold]Database:[/bold]") | |
| db_path = Path(".hae/experiments.db") | |
| if not db_path.exists(): | |
| console.print(" [yellow]⚠️ .hae/experiments.db not found[/yellow]") | |
| console.print(" [cyan]💡 Action:[/cyan] Run 'hae experiment sync'\n") | |
| warnings += 1 | |
| else: | |
| try: | |
| db = ExperimentDB() | |
| def_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_definitions").fetchone()[0] | |
| exec_count = db.cursor.execute("SELECT COUNT(*) FROM experiment_executions").fetchone()[0] | |
| if def_count == 0 and status.new_definitions_exists: | |
| console.print(f" [yellow]⚠️ Database empty (0 definitions)[/yellow]") | |
| console.print(" [cyan]💡 Action:[/cyan] Run 'hae experiment sync'\n") | |
| warnings += 1 | |
| else: | |
| console.print(f" [green]✅ Database OK ({def_count} definitions, {exec_count} executions)[/green]") | |
| # Integrity check | |
| if verbose: | |
| result = db.cursor.execute("PRAGMA integrity_check").fetchone() | |
| if result[0] == 'ok': | |
| console.print(" [green]✅ Integrity check passed[/green]") | |
| else: | |
| console.print(f" [red]❌ Integrity check failed: {result[0]}[/red]") | |
| errors += 1 | |
| db.close() | |
| console.print() | |
| except Exception as e: | |
| console.print(f" [red]❌ Database error: {e}[/red]\n") | |
| errors += 1 | |
| # ======================================================================== | |
| # 4. Summary | |
| # ======================================================================== | |
| console.print("[bold]Summary:[/bold]") | |
| if errors == 0 and warnings == 0: | |
| console.print("[green]✅ All healthy - no issues found![/green]") | |
| sys.exit(0) | |
| elif errors == 0: | |
| console.print(f"[yellow]⚠️ {warnings} warning(s) found[/yellow]") | |
| sys.exit(0) | |
| else: | |
| console.print(f"[red]❌ {errors} error(s), {warnings} warning(s) found[/red]") | |
| if not fix: | |
| console.print("[dim]Run with --fix to attempt automatic repairs[/dim]") | |
| sys.exit(1) | |
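| # For reference, a minimal sketch of the status object consumed above; the real | |
| # definition lives in hae.migration, and these field names are inferred from usage: | |
| # | |
| #     @dataclass | |
| #     class MigrationStatus: | |
| #         old_definitions_exists: bool | |
| #         old_definitions_count: int | |
| #         old_history_exists: bool | |
| #         old_history_count: int | |
| #         new_definitions_exists: bool | |
| #         new_definitions_count: int | |
| #         new_history_exists: bool | |
| #         new_history_count: int | |
| #         needs_migration: bool | |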
| @experiment.command('migrate') | |
| @click.option('--dry-run', is_flag=True, default=True, help='Preview changes without applying (default)') | |
| @click.option('--apply', is_flag=True, help='Apply migration (bypasses dry-run)') | |
| @click.option('--no-backup', is_flag=True, help='Skip backup creation (not recommended)') | |
| @click.option('--force', is_flag=True, help='Skip confirmation prompts') | |
| def experiment_migrate(dry_run, apply, no_backup, force): | |
| """Migrate experiment structure from old to new format. | |
| This command safely migrates your experiments from the old directory | |
| structure (experiments/definitions/) to the new structure (definitions/). | |
| By default, shows a dry-run preview and asks for confirmation. | |
| Examples: | |
| hae experiment migrate # Preview migration (dry-run) | |
| hae experiment migrate --apply # Apply migration immediately | |
| hae experiment migrate --apply --no-backup # Apply without creating a backup | |
| """ | |
| from hae.migration import ( | |
| check_migration_status, | |
| create_backup, | |
| migrate_directories, | |
| fix_yaml_paths | |
| ) | |
| from hae.db import ExperimentDB | |
| console.print("[bold cyan]🔄 Experiment Structure Migration[/bold cyan]\n") | |
| # Check if migration is needed | |
| status = check_migration_status() | |
| if not status.needs_migration: | |
| console.print("[green]✓ Already migrated - no action needed![/green]") | |
| console.print("[dim]Current structure:[/dim]") | |
| console.print(f" • definitions/ ({status.new_definitions_count} files)") | |
| console.print(f" • experiment-history/ ({status.new_history_count} records)") | |
| return | |
| # Determine whether we're actually applying or just showing a dry-run. | |
| # --dry-run is a default-on flag that click never sets to False, so it must | |
| # not gate --apply; otherwise the migration could never be executed. | |
| actually_apply = apply | |
| if actually_apply: | |
| console.print("[bold]Migration Plan:[/bold]\n") | |
| else: | |
| console.print("[bold][DRY RUN] Preview of changes:[/bold]\n") | |
| # ======================================================================== | |
| # 1. Backup | |
| # ======================================================================== | |
| if not no_backup: | |
| if actually_apply: | |
| console.print("1. [bold]Backup:[/bold]") | |
| try: | |
| backup_dir = create_backup() | |
| console.print(f" [green]✓ Created backup: {backup_dir}/[/green]\n") | |
| except Exception as e: | |
| console.print(f" [red]❌ Backup failed: {e}[/red]") | |
| console.print("[yellow]Aborting migration for safety[/yellow]") | |
| sys.exit(1) | |
| else: | |
| from datetime import datetime  # local import, mirroring this function's style | |
| timestamp = datetime.now().strftime("%Y%m%d-%H%M%S") | |
| console.print(f"1. [bold]Backup:[/bold]") | |
| console.print(f" ✓ Would create .migration-backup-{timestamp}/\n") | |
| else: | |
| console.print("1. [bold]Backup:[/bold] Skipped (--no-backup)\n") | |
| # ======================================================================== | |
| # 2. Move Directories | |
| # ======================================================================== | |
| console.print("2. [bold]Move Directories:[/bold]") | |
| success, actions = migrate_directories(dry_run=(not actually_apply)) | |
| for action in actions: | |
| if action.startswith("Would"): | |
| console.print(f" {action}") | |
| elif action.startswith("Moved") or action.startswith("Removed"): | |
| console.print(f" [green]✓ {action}[/green]") | |
| elif action.startswith("Error"): | |
| console.print(f" [red]❌ {action}[/red]") | |
| if not success: | |
| console.print("\n[red]Migration failed - check errors above[/red]") | |
| sys.exit(1) | |
| console.print() | |
| # ======================================================================== | |
| # 3. Fix Path References | |
| # ======================================================================== | |
| console.print("3. [bold]Update Path References:[/bold]") | |
| yaml_dir = Path("definitions/experiments") | |
| if yaml_dir.exists(): | |
| fixed_count, fix_actions = fix_yaml_paths(yaml_dir, dry_run=(not actually_apply)) | |
| if fixed_count > 0: | |
| for action in fix_actions[:5]: # Show first 5 | |
| if action.startswith("Would"): | |
| console.print(f" {action}") | |
| elif action.startswith("Fixed"): | |
| console.print(f" [green]✓ {action}[/green]") | |
| elif action.startswith("Error"): | |
| console.print(f" [red]❌ {action}[/red]") | |
| if len(fix_actions) > 5: | |
| console.print(f" ... and {len(fix_actions) - 5} more") | |
| else: | |
| console.print(" [dim]No path references to fix[/dim]") | |
| else: | |
| console.print(" [yellow]⚠️ definitions/experiments/ not found (will be created)[/yellow]") | |
| console.print() | |
| # ======================================================================== | |
| # 4. Initialize Database | |
| # ======================================================================== | |
| console.print("4. [bold]Initialize Database:[/bold]") | |
| if actually_apply: | |
| console.print(" [green]✓ Run 'hae experiment sync' after migration to build database[/green]") | |
| else: | |
| console.print(" ✓ Would run 'hae experiment sync' to build database") | |
| console.print() | |
| # ======================================================================== | |
| # Confirmation & Apply | |
| # ======================================================================== | |
| if not actually_apply: | |
| console.print("[yellow]This was a DRY RUN - no changes were made[/yellow]\n") | |
| if not force: | |
| if click.confirm("Apply migration?", default=False): | |
| console.print("\n[cyan]Applying migration...[/cyan]\n") | |
| # Recursively call with --apply --force | |
| import subprocess | |
| result = subprocess.run( | |
| ["hae", "experiment", "migrate", "--apply", "--force"] + | |
| (["--no-backup"] if no_backup else []), | |
| cwd=Path.cwd() | |
| ) | |
| sys.exit(result.returncode) | |
| else: | |
| console.print("[dim]Migration cancelled - run 'hae experiment migrate --apply' when ready[/dim]") | |
| sys.exit(0) | |
| else: | |
| console.print("[dim]Run with --apply to execute migration[/dim]") | |
| sys.exit(0) | |
| else: | |
| console.print("[green]✅ Migration complete![/green]\n") | |
| console.print("[cyan]Verifying...[/cyan]") | |
| # Run doctor to verify | |
| import subprocess | |
| subprocess.run(["hae", "experiment", "doctor", "--migration-only"], cwd=Path.cwd()) | |
| @main.group() | |
| def compare(): | |
| """Compare files from different runs or implementations. | |
| Supports comparison of images, JSON metrics, NPY arrays, time-series, and text files. | |
| """ | |
| pass | |
| @compare.command('images') | |
| @click.argument('files', nargs=-1, type=click.Path(exists=True)) | |
| @click.option('--pattern', multiple=True, help='Glob pattern to find images (can be used multiple times)') | |
| @click.option('--dirs', multiple=True, help='Directories to search for images (can be used multiple times)') | |
| @click.option('--match-by', type=click.Choice(['name', 'index', 'metadata']), default='name', | |
| help='How to match files across directories (default: name)') | |
| @click.option('--plot', type=click.Choice(['grid', 'diff', 'before-after', 'statistics', 'multi-panel']), | |
| default='grid', help='Plot type (default: grid)') | |
| @click.option('--grid-cols', type=int, help='Number of columns for grid plot') | |
| @click.option('--labels', multiple=True, help='Labels for images (use --labels for each label, or comma-separated)') | |
| @click.option('--title', help='Plot title') | |
| @click.option('--output', type=click.Path(), help='Output file path (for single comparison)') | |
| @click.option('--output-dir', type=click.Path(), help='Output directory (for batch comparisons)') | |
| @click.option('--output-format', type=click.Choice(['png', 'pdf', 'svg', 'html']), default='png', | |
| help='Output format (default: png)') | |
| @click.option('--dpi', type=int, default=100, help='DPI for output (default: 100)') | |
| @click.option('--diff', is_flag=True, help='Show pixel differences') | |
| @click.option('--metrics', multiple=True, help='Metrics to compute (mse, ssim, mean_diff, max_diff)') | |
| @click.option('--verbose', '-v', is_flag=True, help='Verbose output') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be done without executing') | |
| def compare_images(files, pattern, dirs, match_by, plot, grid_cols, labels, title, output, | |
| output_dir, output_format, dpi, diff, metrics, verbose, dry_run): | |
| """Compare image files. | |
| Can compare specific files, files matching patterns, or files in directories. | |
| Examples: | |
| hae compare images img1.png img2.png img3.png --plot grid --labels "A,B,C" | |
| hae compare images --pattern "**/anomaly_maps/*.png" --plot grid --output-dir ./comparisons | |
| hae compare images --dirs dir1 dir2 dir3 --match-by name --plot grid --labels "process,evaluator,ad" | |
| """ | |
| from .commands.compare import compare_images_cli | |
| # Convert click types to lists | |
| patterns = list(pattern) if pattern else [] | |
| dirs_list = list(dirs) if dirs else None | |
| # Handle labels - support both comma-separated and multiple flags | |
| labels_list = [] | |
| if labels: | |
| for label in labels: | |
| # Support comma-separated labels: --labels "a,b,c" | |
| if ',' in label: | |
| labels_list.extend([l.strip() for l in label.split(',')]) | |
| else: | |
| labels_list.append(label) | |
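| # Examples of the parsing above (derived from the branches): | |
| #   --labels "a,b,c"        -> labels_list == ['a', 'b', 'c'] | |
| #   --labels a --labels b   -> labels_list == ['a', 'b'] | |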
| metrics_list = list(metrics) if metrics else None | |
| # If files are provided directly, add them as patterns | |
| if files: | |
| patterns.extend([str(f) for f in files]) | |
| try: | |
| compare_images_cli( | |
| patterns=patterns, | |
| dirs=dirs_list, | |
| match_by=match_by, | |
| plot=plot, | |
| grid_cols=grid_cols, | |
| labels=labels_list, | |
| title=title, | |
| output=output, | |
| output_dir=output_dir, | |
| output_format=output_format, | |
| dpi=dpi, | |
| diff=diff, | |
| metrics=metrics_list, | |
| verbose=verbose, | |
| dry_run=dry_run | |
| ) | |
| console.print("[green]✓ Comparison complete[/green]") | |
| except Exception as e: | |
| console.print(f"[red]✗ Error: {e}[/red]") | |
| if verbose: | |
| import traceback | |
| traceback.print_exc() | |
| sys.exit(1) | |
| @compare.command('json') | |
| @click.argument('files', nargs=-1, type=click.Path(exists=True)) | |
| @click.option('--pattern', multiple=True, help='Glob pattern to find JSON files') | |
| @click.option('--dirs', multiple=True, help='Directories to search for JSON files') | |
| @click.option('--extract', multiple=True, help='Extract specific values (e.g., "evaluation.auroc")') | |
| @click.option('--extract-all', is_flag=True, help='Extract all metrics') | |
| @click.option('--diff', is_flag=True, help='Show differences') | |
| @click.option('--table', is_flag=True, help='Show as table') | |
| @click.option('--compute-deltas', is_flag=True, help='Compute differences/deltas') | |
| @click.option('--output', type=click.Path(), help='Output file path') | |
| @click.option('--output-format', type=click.Choice(['csv', 'json', 'markdown', 'table']), default='table', | |
| help='Output format (default: table)') | |
| @click.option('--verbose', '-v', is_flag=True, help='Verbose output') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be done without executing') | |
| def compare_json(files, pattern, dirs, extract, extract_all, diff, table, compute_deltas, | |
| output, output_format, verbose, dry_run): | |
| """Compare JSON files (metrics, results, etc.). | |
| Examples: | |
| hae compare json metrics1.json metrics2.json --extract evaluation.auroc --table | |
| hae compare json --pattern "**/final_metrics.json" --extract-all --output comparison.csv | |
| """ | |
| console.print("[yellow]JSON comparison not yet implemented[/yellow]") | |
| console.print("[dim]This feature is planned but not yet available[/dim]") | |
| @compare.command('npy') | |
| @click.argument('files', nargs=-1, type=click.Path(exists=True)) | |
| @click.option('--pattern', multiple=True, help='Glob pattern to find NPY files') | |
| @click.option('--dirs', multiple=True, help='Directories to search for NPY files') | |
| @click.option('--metrics', multiple=True, help='Metrics to compute (mse, mae, correlation)') | |
| @click.option('--output', type=click.Path(), help='Output file path') | |
| @click.option('--verbose', '-v', is_flag=True, help='Verbose output') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be done without executing') | |
| def compare_npy(files, pattern, dirs, metrics, output, verbose, dry_run): | |
| """Compare NPY array files. | |
| Examples: | |
| hae compare npy array1.npy array2.npy --metrics mse --metrics mae | |
| hae compare npy --pattern "**/*.npy" --metrics mse --output comparison.txt | |
| """ | |
| console.print("[yellow]NPY comparison not yet implemented[/yellow]") | |
| console.print("[dim]This feature is planned but not yet available[/dim]") | |
| @compare.command('time-series') | |
| @click.argument('files', nargs=-1, type=click.Path(exists=True)) | |
| @click.option('--pattern', multiple=True, help='Glob pattern to find time-series files') | |
| @click.option('--metric', help='Metric to extract (e.g., "training.loss")') | |
| @click.option('--x-axis', help='X-axis field (e.g., "epoch")') | |
| @click.option('--plot-type', type=click.Choice(['overlaid', 'separate', 'diff']), default='overlaid', | |
| help='Plot type (default: overlaid)') | |
| @click.option('--labels', multiple=True, help='Labels for series') | |
| @click.option('--output', type=click.Path(), help='Output file path') | |
| @click.option('--verbose', '-v', is_flag=True, help='Verbose output') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be done without executing') | |
| def compare_time_series(files, pattern, metric, x_axis, plot_type, labels, output, verbose, dry_run): | |
| """Compare time-series data. | |
| Examples: | |
| hae compare time-series run1/metrics.json run2/metrics.json --metric training.loss --x-axis epoch | |
| """ | |
| console.print("[yellow]Time-series comparison not yet implemented[/yellow]") | |
| console.print("[dim]This feature is planned but not yet available[/dim]") | |
| @compare.command('text') | |
| @click.argument('files', nargs=-1, type=click.Path(exists=True)) | |
| @click.option('--pattern', multiple=True, help='Glob pattern to find text files') | |
| @click.option('--diff', is_flag=True, help='Line-by-line diff') | |
| @click.option('--extract-patterns', multiple=True, help='Extract values using regex patterns') | |
| @click.option('--summary', is_flag=True, help='Summary comparison') | |
| @click.option('--output', type=click.Path(), help='Output file path') | |
| @click.option('--verbose', '-v', is_flag=True, help='Verbose output') | |
| @click.option('--dry-run', is_flag=True, help='Show what would be done without executing') | |
| def compare_text(files, pattern, diff, extract_patterns, summary, output, verbose, dry_run): | |
| """Compare text/log files. | |
| Examples: | |
| hae compare text log1.txt log2.txt --diff --output diff.txt | |
| hae compare text --pattern "**/*.log" --extract-patterns "loss=([0-9.]+)" --summary | |
| """ | |
| console.print("[yellow]Text comparison not yet implemented[/yellow]") | |
| console.print("[dim]This feature is planned but not yet available[/dim]") | |
| if __name__ == '__main__': | |
| main() |
| name: iris-eval-ad-detect-and-evaluate-256 | |
| description: | | |
| Complete pipeline that evaluates a DeCo-Diff model using AnomalyDetector for the Iris dataset. | |
| sweep: | |
| type: grid | |
| params: | |
| sweep_dataset: [ "80_Real_merged_head", "280_False_merged_defect_modulo_obvious_defect", "280_False_merged_obvious_defect", "280_False_merged_normal_with_perfect_plain_10_with_split_only_test", "280_False_merged_normal_with_perfect_plain_10_with_split_only_train"] | |
| variables: | |
| # Dataset configuration | |
| dataset: iris | |
| data_dir: D:/HDS_Mask/SL_DL_Ori/save/EUV | |
| object_category: all | |
| model_size: UNet_L | |
| image_size: 256 | |
| center_crop: 256 | |
| model_path: D:/HDS/DiOodMi___250917/results/decodiff_iris_basic_augmentation_with_split/checkpoints/epoch-857.pt | |
| csv_split_file: "../splits/iris-split_{sweep_dataset}.csv" # Custom CSV split file | |
| eval_output_dir_ad: "{EXECUTION_DIR}/{sweep_dataset}" | |
| save_plot_path: "{EXECUTION_DIR}/{sweep_dataset}/save_plot_path" | |
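| # Example expansion for the first sweep value (EXECUTION_DIR is resolved at | |
| # runtime by the experiment runner): | |
| #   csv_split_file     -> ../splits/iris-split_80_Real_merged_head.csv | |
| #   eval_output_dir_ad -> {EXECUTION_DIR}/80_Real_merged_head | |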
| # Evaluation configuration (deterministic settings required!) | |
| crop_size: 256 | |
| patch_size: 256 # For eval-process (processor uses patch_size) | |
| batch_size: 10 | |
| pad_px: 2 | |
| stride: 1 | |
| directions: h v | |
| tta_batch_size: 2 | |
| # Evaluation processing parameters | |
| reverse_steps: 5 # Number of reverse diffusion steps | |
| num_workers: 0 # Number of DataLoader workers (0 for Windows compatibility) | |
| vae_type: ema # VAE type (ema or mse) | |
| shift_method: mirror # TTA shift augmentation method | |
| fuse_method: pct25 # Method to fuse multiple shifted predictions | |
| img_enlarge_px: 4 # Image enlargement for interpolation method | |
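| # Note (assumption, not confirmed by this file): fuse_method pct25 presumably | |
| # takes the 25th percentile across stacked shifted predictions, roughly | |
| # np.percentile(stack, 25, axis=0). | |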
| # Processor-specific parameters | |
| annotation_dir: ./annotations # Directory containing annotation JSON files (REQUIRED for eval-process) | |
| split: test # Data split to process (for eval-process) | |
| # Reconstruction saving options (optional) | |
| save_reconstructions_flag: "" # Flag for saving VAE and Deco-Diff reconstructions as PNG | |
| save_reconstructions_npy_flag: "" # Flag for saving raw reconstruction arrays as NPY files | |
| save_raw_anomaly_maps_flag: "" # Flag for saving raw anomaly map arrays as NPY files | |
| # Visualization variant options (optional) | |
| save_image_variants_flag: "" # Variants to save: continuous binary grayscale absolute | |
| save_colormaps_flag: "" # Colormaps: jet hot viridis plasma magma inferno turbo | |
| save_normalization: minmax # Normalization method: minmax zscore percentile raw | |
| normal_center: 125.0 # Center value for grayscale normalization | |
| save_binary_thresholds_flag: "" # Thresholds: e.g., "--save-binary-thresholds 5.0 10.0 15.0" | |
| # Dataset sampling options (optional - set to null to use full dataset) | |
| bootstrap_samples: null | |
| bootstrap_seed: 42 | |
| first_n: null | |
| last_n: null | |
| range_start: null | |
| range_count: 10 | |
| random_n: null | |
| random_n_with_replacement: null | |
| random_seed: 42 | |
| # Backend settings for determinism | |
| torch_deterministic: true | |
| cudnn_benchmark: false | |
| allow_tf32: false | |
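| # These typically map onto the standard PyTorch switches (illustrative): | |
| #   torch.use_deterministic_algorithms(True) | |
| #   torch.backends.cudnn.benchmark = False | |
| #   torch.backends.cuda.matmul.allow_tf32 = False | |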
| experiments: | |
| # =================================================================== | |
| # Detect and Evaluate (detection + evaluation in a single combined step) | |
| # =================================================================== | |
| - name: detect-and-evaluate | |
| description: "Detection and evaluate" | |
| pipeline: ../pipelines/eval-ad.yaml | |
| variables: | |
| mode: "detect_and_evaluate" | |
| model_path: "{model_path}" | |
| dataset: "{dataset}" | |
| data_dir: "{data_dir}" | |
| object_category: "{object_category}" | |
| csv_split_file: "{csv_split_file}" | |
| model_size: "{model_size}" | |
| crop_size: "{crop_size}" | |
| batch_size: "{batch_size}" | |
| pad_px: "{pad_px}" | |
| tta_batch_size: "{tta_batch_size}" | |
| stride: "{stride}" | |
| directions: "{directions}" | |
| shift_method: "{shift_method}" | |
| fuse_method: "{fuse_method}" | |
| img_enlarge_px: "{img_enlarge_px}" | |
| reverse_steps: "{reverse_steps}" | |
| num_workers: "{num_workers}" | |
| vae_type: "{vae_type}" | |
| output_dir: "{eval_output_dir_ad}" | |
| # Reconstruction saving options | |
| save_reconstructions_flag: "{save_reconstructions_flag}" | |
| save_reconstructions_npy_flag: "{save_reconstructions_npy_flag}" | |
| save_raw_anomaly_maps_flag: "{save_raw_anomaly_maps_flag}" | |
| # Visualization options | |
| save_image_variants_flag: "{save_image_variants_flag}" | |
| save_colormaps_flag: "{save_colormaps_flag}" | |
| save_normalization: "{save_normalization}" | |
| normal_center: "{normal_center}" | |
| save_binary_thresholds_flag: "{save_binary_thresholds_flag}" | |
| # Evaluation sampling options (set to null to use full dataset) | |
| bootstrap_samples: "{bootstrap_samples}" | |
| bootstrap_seed: "{bootstrap_seed}" | |
| first_n: "{eval_first_n}" | |
| last_n: "{last_n}" | |
| range_start: "{range_start}" | |
| range_count: "{range_count}" | |
| random_n: "{random_n}" | |
| random_n_with_replacement: "{random_n_with_replacement}" | |
| random_seed: "{random_seed}" | |
| backend: | |
| type: local | |
| parallel: false # Single combined detect_and_evaluate step; nothing to run in parallel | |
| continue_on_error: false # Stop if detection fails |