Created
January 29, 2026 08:53
-
-
Save AlexandreAbraham/540991389d8091ae3a12868ecaf2daf0 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| """ | |
| Comprehensive RoRA-Tab Benchmark with: | |
| - 5-fold stratified cross-validation | |
| - Real TabICL/TabPFN predictions | |
| - All 6 RoRA configs trained to convergence + ensemble | |
| - Reports individual config accuracies AND ensemble | |
| - Non-trivial datasets (madeline, etc.) | |
| - Tuned XGBoost baseline | |
| """ | |
| import sys | |
| sys.path.insert(0, '/Users/aabraham/rora-tab') | |
| import numpy as np | |
| import pandas as pd | |
| import torch | |
| import torch.nn as nn | |
| from torch.utils.data import DataLoader, TensorDataset | |
| from sklearn.model_selection import StratifiedKFold, train_test_split | |
| from sklearn.preprocessing import LabelEncoder, StandardScaler | |
| from sklearn.impute import SimpleImputer | |
| from sklearn.datasets import fetch_openml | |
| import xgboost as xgb | |
| from scipy import stats | |
| import warnings | |
| import time | |
| import json | |
| import hashlib | |
| import os | |
| from pathlib import Path | |
| from typing import Dict, List, Tuple, Optional | |
| warnings.filterwarnings('ignore') | |
# =============================================================================
# CACHING UTILITIES
# =============================================================================
CACHE_DIR = Path(__file__).parent / "cache"


def get_cache_key(dataset_name: str, fold: int, method: str, seed: int) -> str:
    """Build the unique identifier under which one fold/method result is cached."""
    return "_".join([dataset_name, f"fold{fold}", method, f"seed{seed}"])


def get_cache_path(cache_key: str) -> Path:
    """Return the JSON file path for *cache_key*, creating the cache dir if needed."""
    CACHE_DIR.mkdir(exist_ok=True)
    return CACHE_DIR / (cache_key + ".json")
def load_cached_result(dataset_name: str, fold: int, method: str, seed: int) -> Optional[Dict]:
    """Return the cached result dict for this fold/method, or None on a miss.

    A corrupt or unreadable cache file is treated the same as a cache miss.
    """
    path = get_cache_path(get_cache_key(dataset_name, fold, method, seed))
    if not path.exists():
        return None
    try:
        with open(path, 'r') as fh:
            return json.load(fh)
    except (json.JSONDecodeError, IOError):
        # Truncated/garbage cache files silently fall back to recomputation.
        return None
def save_cached_result(dataset_name: str, fold: int, method: str, seed: int, result: Dict):
    """Persist *result* as JSON under the cache key for this fold/method."""
    path = get_cache_path(get_cache_key(dataset_name, fold, method, seed))
    # mkdir is idempotent; get_cache_path also ensures the directory exists.
    CACHE_DIR.mkdir(exist_ok=True)
    with open(path, 'w') as fh:
        json.dump(result, fh)
def clear_cache(dataset_name: Optional[str] = None):
    """Delete cached result files — all of them when *dataset_name* is None."""
    if not CACHE_DIR.exists():
        return
    pattern = "*.json" if dataset_name is None else f"{dataset_name}_*.json"
    for cached_file in CACHE_DIR.glob(pattern):
        cached_file.unlink()
    if dataset_name is None:
        print("Cleared all cached results")
    else:
        print(f"Cleared cached results for {dataset_name}")
from rora_tab import RoRATabClassifier
from rora_tab.rora_layer import RoRALayer
from rora_tab.transformer import RandomTransformerEncoder

# Optional baselines: record availability instead of failing at import time.
try:
    from tabicl import TabICLClassifier
except ImportError:
    TABICL_AVAILABLE = False
    print("Warning: TabICL not installed. Install with: pip install tabicl")
else:
    TABICL_AVAILABLE = True

try:
    from tabpfn import TabPFNClassifier
except ImportError:
    TABPFN_AVAILABLE = False
    print("Warning: TabPFN not installed. Install with: pip install tabpfn")
else:
    TABPFN_AVAILABLE = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# =============================================================================
# DATASETS
# =============================================================================
# (name, OpenML version, human-readable description) — classic benchmarks
# first, then deliberately non-trivial datasets (high-dimensional, imbalanced,
# multi-class).
_DATASET_SPECS = [
    ('credit-g', 1, 'German Credit (1000 samples, 20 features)'),
    ('diabetes', 1, 'Pima Indians Diabetes (768 samples, 8 features)'),
    ('Australian', 4, 'Australian Credit (690 samples, 14 features)'),
    ('vehicle', 1, 'Vehicle Silhouettes (846 samples, 18 features)'),
    ('phoneme', 1, 'Phoneme (5404 samples, 5 features)'),
    ('blood-transfusion-service-center', 1, 'Blood Transfusion (748 samples, 4 features)'),
    ('kc1', 1, 'KC1 Software Defect (2109 samples, 21 features)'),
    ('madeline', 1, 'Madeline (3140 samples, 259 features) - High dimensional'),
    ('madelon', 1, 'Madelon (2600 samples, 500 features) - NIPS challenge'),
    ('wilt', 2, 'Wilt Detection (4839 samples, 5 features) - Imbalanced'),
    ('segment', 1, 'Image Segmentation (2310 samples, 19 features) - Multi-class'),
    ('steel-plates-fault', 1, 'Steel Plates Fault (1941 samples, 27 features) - Multi-class'),
    ('SpeedDating', 1, 'Speed Dating (8378 samples, 120 features) - Many features'),
    ('ozone-level-8hr', 1, 'Ozone Level (2534 samples, 72 features) - Imbalanced'),
]
DATASETS = [
    {'name': name, 'version': version, 'description': description}
    for name, version, description in _DATASET_SPECS
]
def load_openml_dataset(name: str, version: int = 1) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
    """Load and preprocess an OpenML dataset.

    Categorical/object columns are integer-encoded, missing values are
    median-imputed, and the target is label-encoded.

    Args:
        name: OpenML dataset name.
        version: OpenML dataset version.

    Returns:
        (X, y) as float32 features and integer labels, or (None, None) when
        the download or preprocessing fails — the error is printed, not
        raised, so callers can skip the dataset.
    """
    print(f" Loading {name}...", end=" ", flush=True)
    try:
        data = fetch_openml(name=name, version=version, as_frame=True, parser='auto')
        # Work on a copy so the frame cached inside fetch_openml is never
        # mutated (also avoids pandas SettingWithCopy warnings).
        X = data.data.copy()
        y = data.target
        # Integer-encode categorical features; NaNs become the string 'nan'
        # and so get their own category before imputation.
        for col in X.columns:
            if X[col].dtype == 'object' or X[col].dtype.name == 'category':
                X[col] = LabelEncoder().fit_transform(X[col].astype(str))
        X = X.values.astype(np.float32)
        X = SimpleImputer(strategy='median').fit_transform(X)
        y = LabelEncoder().fit_transform(y)
        print(f"OK (shape={X.shape}, classes={len(np.unique(y))})")
        return X, y
    except Exception as e:
        # Best-effort benchmark loader: report and let the caller skip.
        print(f"FAILED: {e}")
        return None, None
# =============================================================================
# RoRA-Tab Module with Bottleneck Option
# =============================================================================
class RoRATabModuleWithBottleneck(nn.Module):
    """RoRA-Tab classifier head with an optional compress/expand bottleneck.

    After construction only the RoRA layer's parameters remain trainable; the
    input projection, transformer, and classifier weights are frozen.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_classes: int,
        n_layers: int,
        rora_rank: int,
        bottleneck_dim: Optional[int] = None,
        init_scale: float = 0.01
    ):
        super().__init__()
        self.use_bottleneck = bottleneck_dim is not None
        # Input projection: either a two-stage bottleneck (compress -> ReLU ->
        # expand) or a single linear map into the hidden space.
        if self.use_bottleneck:
            self.compress = nn.Linear(input_dim, bottleneck_dim)
            self.expand = nn.Linear(bottleneck_dim, hidden_dim)
        else:
            self.proj = nn.Linear(input_dim, hidden_dim)
        self.transformer = RandomTransformerEncoder(
            dim=hidden_dim, n_layers=n_layers, n_heads=4, mlp_ratio=4
        )
        self.rora = RoRALayer(hidden_dim, rora_rank, init_scale=init_scale)
        self.classifier = nn.Linear(hidden_dim, num_classes)
        self._freeze_non_rora()

    def _freeze_non_rora(self):
        """Disable gradients for every parameter whose name lacks 'rora'."""
        non_rora = (p for n, p in self.named_parameters() if 'rora' not in n)
        for param in non_rora:
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project, encode, apply RoRA, and classify a batch of inputs."""
        if self.use_bottleneck:
            hidden = self.expand(torch.relu(self.compress(x)))
        else:
            hidden = self.proj(x)
        return self.classifier(self.rora(self.transformer(hidden)))
# =============================================================================
# TRAINING UTILITIES
# =============================================================================
def train_rora_model(
    model: nn.Module,
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_val: np.ndarray,
    y_val: np.ndarray,
    epochs: int = 500,
    lr: float = 5e-4,
    batch_size: int = 32,
    patience: int = 50,
    weight_decay: float = 1e-4
) -> Tuple[nn.Module, float]:
    """Train *model* with AdamW + cosine LR, early-stopping on val accuracy.

    Only parameters with requires_grad=True are optimized. The model is
    restored to its best-validation weights before returning.

    Returns:
        (model, best validation accuracy).
    """
    model = model.to(device)
    loader = DataLoader(
        TensorDataset(
            torch.tensor(X_train, dtype=torch.float32),
            torch.tensor(y_train, dtype=torch.long),
        ),
        batch_size=batch_size,
        shuffle=True,
    )
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable, lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=lr / 10)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    best_state = None
    epochs_without_improvement = 0
    # Validation tensor is loop-invariant; build it once.
    X_val_t = torch.tensor(X_val, dtype=torch.float32).to(device)

    for _ in range(epochs):
        model.train()
        for xb, yb in loader:
            xb, yb = xb.to(device), yb.to(device)
            optimizer.zero_grad()
            criterion(model(xb), yb).backward()
            optimizer.step()
        scheduler.step()

        # Score the held-out split after every epoch.
        model.eval()
        with torch.no_grad():
            val_preds = model(X_val_t).argmax(dim=1).cpu().numpy()
        acc = (val_preds == y_val).mean()

        if acc > best_acc:
            best_acc = acc
            # CPU snapshot so restoring works regardless of training device.
            best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                break

    # Restore the best checkpoint (if validation ever improved on zero).
    if best_state is not None:
        model.load_state_dict(best_state)
        model.to(device)
    return model, best_acc
def predict_with_model(model: nn.Module, X: np.ndarray) -> np.ndarray:
    """Return hard class predictions (argmax over logits) for *X*."""
    model.eval()
    inputs = torch.tensor(X, dtype=torch.float32).to(device)
    with torch.no_grad():
        logits = model(inputs)
    return logits.argmax(dim=1).cpu().numpy()
# =============================================================================
# RORA CONFIGURATIONS
# =============================================================================
# Each config pairs an XGBoost tree depth (for the leaf embeddings) with an
# optional bottleneck width for the RoRA input projection.
_CONFIG_GRID = [
    ('shallow_no_bn', 1, None),
    ('shallow_bn32', 1, 32),
    ('medium_no_bn', 3, None),
    ('medium_bn64', 3, 64),
    ('deep_no_bn', 4, None),
    ('deep_bn64', 4, 64),
]
RORA_CONFIGS = [
    {'name': name, 'xgb_depth': depth, 'bottleneck': bottleneck}
    for name, depth, bottleneck in _CONFIG_GRID
]
def get_model_probs(model: nn.Module, X: np.ndarray) -> np.ndarray:
    """Return softmax class probabilities for *X* from a trained model."""
    model.eval()
    with torch.no_grad():
        logits = model(torch.tensor(X, dtype=torch.float32).to(device))
        return torch.softmax(logits, dim=1).cpu().numpy()
def train_all_rora_configs(
    X_train_scaled: np.ndarray,
    y_train: np.ndarray,
    X_test_scaled: np.ndarray,
    y_test: np.ndarray,
    seed: int = 42
) -> Dict:
    """
    Train all 6 RoRA configs to convergence and return individual + ensemble results.

    Pipeline per config: fit an XGBoost of the configured depth, use its leaf
    indices as embeddings, standardize them, train a RoRA head with early
    stopping on a 10% stratified split, and score on the test leaves. The
    ensemble averages the per-config test probabilities.

    Returns dict with:
    - Individual config accuracies (keyed by config name)
    - Ensemble accuracy (key 'ensemble')
    """
    n_classes = len(np.unique(y_train))
    results = {}
    all_test_probs = []
    for config in RORA_CONFIGS:
        config_name = config['name']
        print(f" {config_name}...", end=" ", flush=True)
        # Train XGBoost on full training data
        xgb_model = xgb.XGBClassifier(
            n_estimators=200,
            max_depth=config['xgb_depth'],
            learning_rate=0.1,
            random_state=seed,
            use_label_encoder=False,
            eval_metric='mlogloss'
        )
        xgb_model.fit(X_train_scaled, y_train)
        # Get leaf embeddings: apply() yields one leaf index per tree per sample.
        all_train_leaves = xgb_model.apply(X_train_scaled).astype(np.float32)
        test_leaves = xgb_model.apply(X_test_scaled).astype(np.float32)
        # Scale leaves — scaler is fit on train leaves only (no test leakage).
        leaf_scaler = StandardScaler()
        all_train_leaves = leaf_scaler.fit_transform(all_train_leaves)
        test_leaves = leaf_scaler.transform(test_leaves)
        # Split for early stopping (90/10 stratified)
        train_idx, val_idx = train_test_split(
            np.arange(len(all_train_leaves)), test_size=0.10,
            random_state=seed, stratify=y_train
        )
        train_leaves = all_train_leaves[train_idx]
        val_leaves = all_train_leaves[val_idx]
        y_tr = y_train[train_idx]
        y_val = y_train[val_idx]
        # Create and train model (only the RoRA layer is trainable; see
        # RoRATabModuleWithBottleneck._freeze_non_rora).
        model = RoRATabModuleWithBottleneck(
            input_dim=train_leaves.shape[1],
            bottleneck_dim=config['bottleneck'],
            hidden_dim=768,
            num_classes=n_classes,
            n_layers=8,
            rora_rank=128
        )
        model, _ = train_rora_model(
            model, train_leaves, y_tr, val_leaves, y_val
        )
        # Get predictions; keep full probability vectors for the ensemble.
        test_probs = get_model_probs(model, test_leaves)
        test_preds = test_probs.argmax(axis=1)
        acc = (test_preds == y_test).mean()
        results[config_name] = acc
        all_test_probs.append(test_probs)
        print(f"{acc:.4f}")
    # Ensemble: average probabilities across all configs, then argmax.
    ensemble_probs = np.mean(all_test_probs, axis=0)
    ensemble_preds = ensemble_probs.argmax(axis=1)
    results['ensemble'] = (ensemble_preds == y_test).mean()
    print(f" ensemble: {results['ensemble']:.4f}")
    return results
# =============================================================================
# TUNED XGBOOST BASELINE
# =============================================================================
def train_tuned_xgboost(
    X_train: np.ndarray,
    y_train: np.ndarray,
    seed: int = 42
) -> Dict:
    """
    Select XGBoost hyperparameters on a held-out validation split.

    Fits each candidate configuration on 80% of the training data (scaled) and
    scores it on the remaining 20%; the caller is expected to retrain on the
    full training set with the winning parameters.

    Returns:
        The best-performing parameter dict from the grid — NOT a fitted model
        (the original annotation said XGBClassifier, which was wrong). In the
        degenerate case where every candidate scores 0.0 this returns None.
    """
    X_tr, X_val, y_tr, y_val = train_test_split(
        X_train, y_train, test_size=0.2, random_state=seed, stratify=y_train
    )
    # Scale inside the tuning split so validation never leaks into the scaler.
    scaler = StandardScaler()
    X_tr_scaled = scaler.fit_transform(X_tr)
    X_val_scaled = scaler.transform(X_val)
    # Small grid spanning shallow/deep trees and slow/fast learning rates.
    param_grid = [
        {'n_estimators': 200, 'max_depth': 1, 'learning_rate': 0.1},
        {'n_estimators': 200, 'max_depth': 2, 'learning_rate': 0.1},
        {'n_estimators': 100, 'max_depth': 3, 'learning_rate': 0.1},
        {'n_estimators': 200, 'max_depth': 4, 'learning_rate': 0.1},
        {'n_estimators': 300, 'max_depth': 5, 'learning_rate': 0.05},
        {'n_estimators': 200, 'max_depth': 6, 'learning_rate': 0.1},
        {'n_estimators': 500, 'max_depth': 4, 'learning_rate': 0.05},
        {'n_estimators': 100, 'max_depth': 8, 'learning_rate': 0.1},
    ]
    best_params = None
    best_acc = 0
    for params in param_grid:
        model = xgb.XGBClassifier(
            **params,
            random_state=seed,
            use_label_encoder=False,
            eval_metric='mlogloss'
        )
        model.fit(X_tr_scaled, y_tr)
        acc = (model.predict(X_val_scaled) == y_val).mean()
        if acc > best_acc:
            best_acc = acc
            best_params = params.copy()
    return best_params
# =============================================================================
# FOLD EVALUATION
# =============================================================================
def evaluate_fold(
    X_train: np.ndarray,
    y_train: np.ndarray,
    X_test: np.ndarray,
    y_test: np.ndarray,
    seed: int,
    fold: int,
    dataset_name: str,
    quick_mode: bool = False,
    use_cache: bool = True,
    force_recompute: Optional[List[str]] = None
) -> Dict:
    """
    Evaluate all methods (RoRA configs + ensemble, tuned XGBoost, TabICL,
    TabPFN) on a single CV fold, consulting/updating the on-disk cache.

    Args:
        force_recompute: List of methods to recompute even if cached.
                         Options: ['rora', 'xgb', 'tabicl', 'tabpfn']

    Returns:
        Flat dict of per-method accuracies/times for this fold; entries for
        unavailable or failed methods are None.
    """
    if force_recompute is None:
        force_recompute = []
    results = {'fold': fold}
    n_classes = len(np.unique(y_train))
    # Scale features — scaler fit on the training split only.
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    # -------------------------------------------------------------------------
    # 1. RoRA-Tab: Train all configs + ensemble
    # -------------------------------------------------------------------------
    cached_rora = None
    if use_cache and 'rora' not in force_recompute:
        cached_rora = load_cached_result(dataset_name, fold, 'rora', seed)
    if cached_rora is not None:
        print(f" RoRA-Tab (cached):")
        for config in RORA_CONFIGS:
            config_name = config['name']
            # Guarded per-config lookup: an older cache may lack newer configs.
            if config_name in cached_rora:
                print(f" {config_name}: {cached_rora[config_name]:.4f}")
                results[f'rora_{config_name}'] = cached_rora[config_name]
        print(f" ensemble: {cached_rora['ensemble']:.4f}")
        results['rora_ensemble'] = cached_rora['ensemble']
        results['rora_time'] = cached_rora['time']
    else:
        print(f" Training RoRA-Tab (all configs):")
        start = time.time()
        rora_results = train_all_rora_configs(
            X_train_scaled, y_train, X_test_scaled, y_test, seed
        )
        # Store individual config results
        for config in RORA_CONFIGS:
            config_name = config['name']
            results[f'rora_{config_name}'] = rora_results[config_name]
        results['rora_ensemble'] = rora_results['ensemble']
        results['rora_time'] = time.time() - start
        print(f" RoRA-Tab total time: {results['rora_time']:.1f}s")
        # Cache the result (accuracies + total wall time)
        if use_cache:
            cache_data = {
                'ensemble': rora_results['ensemble'],
                'time': results['rora_time']
            }
            for config in RORA_CONFIGS:
                cache_data[config['name']] = rora_results[config['name']]
            save_cached_result(dataset_name, fold, 'rora', seed, cache_data)
    # -------------------------------------------------------------------------
    # 2. Tuned XGBoost Baseline
    # -------------------------------------------------------------------------
    cached_xgb = None
    if use_cache and 'xgb' not in force_recompute:
        cached_xgb = load_cached_result(dataset_name, fold, 'xgb', seed)
    if cached_xgb is not None:
        print(f" XGBoost (tuned): {cached_xgb['acc']:.4f} (cached)")
        results['xgb_acc'] = cached_xgb['acc']
        results['xgb_params'] = cached_xgb['params']
        results['xgb_time'] = cached_xgb['time']
    else:
        print(f" Training XGBoost (tuned)...", end=" ", flush=True)
        start = time.time()
        # Tune on an internal split of the raw training data, then retrain on
        # the full (scaled) training set with the winning parameters.
        best_xgb_params = train_tuned_xgboost(X_train, y_train, seed)
        # Retrain on full training data
        xgb_tuned = xgb.XGBClassifier(
            **best_xgb_params,
            random_state=seed,
            use_label_encoder=False,
            eval_metric='mlogloss'
        )
        xgb_tuned.fit(X_train_scaled, y_train)
        xgb_preds = xgb_tuned.predict(X_test_scaled)
        results['xgb_acc'] = (xgb_preds == y_test).mean()
        results['xgb_time'] = time.time() - start
        results['xgb_params'] = str(best_xgb_params)
        print(f"{results['xgb_acc']:.4f} ({results['xgb_time']:.1f}s)")
        # Cache the result
        if use_cache:
            save_cached_result(dataset_name, fold, 'xgb', seed, {
                'acc': results['xgb_acc'],
                'params': results['xgb_params'],
                'time': results['xgb_time']
            })
    # -------------------------------------------------------------------------
    # 3. TabICL (Real predictions) — note: fit on RAW features, not scaled.
    # -------------------------------------------------------------------------
    if TABICL_AVAILABLE and not quick_mode:
        cached_tabicl = None
        if use_cache and 'tabicl' not in force_recompute:
            cached_tabicl = load_cached_result(dataset_name, fold, 'tabicl', seed)
        if cached_tabicl is not None:
            print(f" TabICL: {cached_tabicl['acc']:.4f} (cached)")
            results['tabicl_acc'] = cached_tabicl['acc']
            results['tabicl_time'] = cached_tabicl['time']
        else:
            print(f" Training TabICL...", end=" ", flush=True)
            start = time.time()
            try:
                # TabICL has sample limits, subsample if needed
                # (assumes a 2048-row context limit — TODO confirm against
                # the installed tabicl version).
                max_train = min(len(X_train), 2048)
                if len(X_train) > max_train:
                    idx = np.random.RandomState(seed).choice(len(X_train), max_train, replace=False)
                    X_tr_icl, y_tr_icl = X_train[idx], y_train[idx]
                else:
                    X_tr_icl, y_tr_icl = X_train, y_train
                tabicl = TabICLClassifier(device='cuda' if torch.cuda.is_available() else 'cpu')
                tabicl.fit(X_tr_icl, y_tr_icl)
                icl_preds = tabicl.predict(X_test)
                results['tabicl_acc'] = (icl_preds == y_test).mean()
                results['tabicl_time'] = time.time() - start
                print(f"{results['tabicl_acc']:.4f} ({results['tabicl_time']:.1f}s)")
                # Cache the result
                if use_cache:
                    save_cached_result(dataset_name, fold, 'tabicl', seed, {
                        'acc': results['tabicl_acc'],
                        'time': results['tabicl_time']
                    })
            except Exception as e:
                # Best-effort: a TabICL failure must not abort the whole fold.
                print(f"FAILED: {e}")
                results['tabicl_acc'] = None
                results['tabicl_time'] = None
    else:
        results['tabicl_acc'] = None
        results['tabicl_time'] = None
        if not TABICL_AVAILABLE:
            print(f" TabICL: Not available")
    # -------------------------------------------------------------------------
    # 4. TabPFN (Real predictions) — also fit on RAW features.
    # -------------------------------------------------------------------------
    if TABPFN_AVAILABLE and not quick_mode:
        cached_tabpfn = None
        if use_cache and 'tabpfn' not in force_recompute:
            cached_tabpfn = load_cached_result(dataset_name, fold, 'tabpfn', seed)
        if cached_tabpfn is not None:
            print(f" TabPFN: {cached_tabpfn['acc']:.4f} (cached)")
            results['tabpfn_acc'] = cached_tabpfn['acc']
            results['tabpfn_time'] = cached_tabpfn['time']
        else:
            print(f" Training TabPFN...", end=" ", flush=True)
            start = time.time()
            try:
                # TabPFN has strict limits: 1000 train, 100 features, 10 classes
                # (assumed limits of TabPFN v1 — TODO confirm for the
                # installed version, which may allow more).
                max_train = min(len(X_train), 1000)
                max_feat = min(X_train.shape[1], 100)
                if len(X_train) > max_train or X_train.shape[1] > max_feat or n_classes > 10:
                    # Subsample rows if needed
                    if len(X_train) > max_train:
                        idx = np.random.RandomState(seed).choice(len(X_train), max_train, replace=False)
                        X_tr_pfn, y_tr_pfn = X_train[idx], y_train[idx]
                    else:
                        X_tr_pfn, y_tr_pfn = X_train, y_train
                    # Truncate to the first max_feat columns (train AND test).
                    if X_tr_pfn.shape[1] > max_feat:
                        X_tr_pfn = X_tr_pfn[:, :max_feat]
                        X_test_pfn = X_test[:, :max_feat]
                    else:
                        X_test_pfn = X_test
                else:
                    X_tr_pfn, y_tr_pfn = X_train, y_train
                    X_test_pfn = X_test
                if n_classes <= 10:
                    tabpfn = TabPFNClassifier(device='cuda' if torch.cuda.is_available() else 'cpu')
                    tabpfn.fit(X_tr_pfn, y_tr_pfn)
                    pfn_preds = tabpfn.predict(X_test_pfn)
                    results['tabpfn_acc'] = (pfn_preds == y_test).mean()
                    results['tabpfn_time'] = time.time() - start
                    print(f"{results['tabpfn_acc']:.4f} ({results['tabpfn_time']:.1f}s)")
                    # Cache the result
                    if use_cache:
                        save_cached_result(dataset_name, fold, 'tabpfn', seed, {
                            'acc': results['tabpfn_acc'],
                            'time': results['tabpfn_time']
                        })
                else:
                    # Class-count limit cannot be worked around by subsampling.
                    print(f"Skipped (>10 classes)")
                    results['tabpfn_acc'] = None
                    results['tabpfn_time'] = None
            except Exception as e:
                # Best-effort: a TabPFN failure must not abort the whole fold.
                print(f"FAILED: {e}")
                results['tabpfn_acc'] = None
                results['tabpfn_time'] = None
    else:
        results['tabpfn_acc'] = None
        results['tabpfn_time'] = None
        if not TABPFN_AVAILABLE:
            print(f" TabPFN: Not available")
    return results
# =============================================================================
# MAIN BENCHMARK
# =============================================================================
def run_benchmark(
    datasets: List[Dict],
    n_folds: int = 5,
    seed: int = 42,
    quick_mode: bool = False,
    use_cache: bool = True,
    force_recompute: Optional[List[str]] = None
) -> pd.DataFrame:
    """
    Run full benchmark with stratified K-fold CV.

    Args:
        datasets: List of dataset configs
        n_folds: Number of CV folds
        seed: Random seed
        quick_mode: If True, skip TabICL/TabPFN for faster testing
        use_cache: If True, use cached results when available
        force_recompute: List of methods to recompute even if cached
                         Options: ['rora', 'xgb', 'tabicl', 'tabpfn']

    Returns:
        DataFrame with one summary row per dataset (means/stds per method,
        winner, and the best individual RoRA config).
    """
    if force_recompute is None:
        force_recompute = []
    all_results = []
    print("\n" + "=" * 80)
    print("RoRA-Tab Comprehensive Benchmark")
    print(f"Configuration: {n_folds}-fold Stratified CV, seed={seed}")
    print(f"Methods: RoRA-Tab (6 configs + ensemble), XGBoost (tuned), TabICL, TabPFN")
    print(f"RoRA configs: {', '.join([c['name'] for c in RORA_CONFIGS])}")
    cache_status = "enabled" if use_cache else "disabled"
    print(f"Caching: {cache_status} (cache dir: {CACHE_DIR})")
    if force_recompute:
        print(f"Force recompute: {', '.join(force_recompute)}")
    print("=" * 80)
    for ds_info in datasets:
        name = ds_info['name']
        version = ds_info.get('version', 1)
        print(f"\n{'='*80}")
        print(f"Dataset: {name}")
        print(f"Description: {ds_info.get('description', 'N/A')}")
        print(f"{'='*80}")
        X, y = load_openml_dataset(name, version)
        # Loader returns (None, None) on failure; skip rather than crash.
        if X is None:
            print(f" Skipping {name} due to loading error")
            continue
        # Create stratified K-fold
        skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=seed)
        fold_results = []
        for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
            print(f"\n Fold {fold + 1}/{n_folds}")
            print(f" " + "-" * 40)
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]
            # Per-fold seed (seed + fold) varies the internal splits AND is
            # part of the cache key, so each fold caches independently.
            result = evaluate_fold(
                X_train, y_train, X_test, y_test,
                seed=seed + fold, fold=fold,
                dataset_name=name,
                quick_mode=quick_mode,
                use_cache=use_cache,
                force_recompute=force_recompute
            )
            result['dataset'] = name
            fold_results.append(result)
        # Aggregate fold results into one summary row per dataset.
        df_folds = pd.DataFrame(fold_results)
        summary = {
            'dataset': name,
            'n_samples': X.shape[0],
            'n_features': X.shape[1],
            'n_classes': len(np.unique(y)),
            'xgb_mean': df_folds['xgb_acc'].mean(),
            'xgb_std': df_folds['xgb_acc'].std(),
        }
        # Aggregate individual RoRA config results
        for config in RORA_CONFIGS:
            config_name = config['name']
            col = f'rora_{config_name}'
            if col in df_folds.columns:
                summary[f'{col}_mean'] = df_folds[col].mean()
                summary[f'{col}_std'] = df_folds[col].std()
        # Ensemble results
        summary['rora_ensemble_mean'] = df_folds['rora_ensemble'].mean()
        summary['rora_ensemble_std'] = df_folds['rora_ensemble'].std()
        # TabICL/TabPFN may be None on some folds; mean over available folds.
        if df_folds['tabicl_acc'].notna().any():
            summary['tabicl_mean'] = df_folds['tabicl_acc'].mean()
            summary['tabicl_std'] = df_folds['tabicl_acc'].std()
        else:
            summary['tabicl_mean'] = None
            summary['tabicl_std'] = None
        if df_folds['tabpfn_acc'].notna().any():
            summary['tabpfn_mean'] = df_folds['tabpfn_acc'].mean()
            summary['tabpfn_std'] = df_folds['tabpfn_acc'].std()
        else:
            summary['tabpfn_mean'] = None
            summary['tabpfn_std'] = None
        # Determine winner (use ensemble for RoRA)
        methods = {'RoRA-Ensemble': summary['rora_ensemble_mean'], 'XGBoost': summary['xgb_mean']}
        if summary['tabicl_mean'] is not None:
            methods['TabICL'] = summary['tabicl_mean']
        if summary['tabpfn_mean'] is not None:
            methods['TabPFN'] = summary['tabpfn_mean']
        summary['winner'] = max(methods, key=lambda k: methods[k])
        # Find best individual RoRA config (by mean accuracy across folds)
        best_config_acc = 0
        best_config_name = None
        for config in RORA_CONFIGS:
            config_name = config['name']
            acc = summary.get(f'rora_{config_name}_mean', 0)
            if acc > best_config_acc:
                best_config_acc = acc
                best_config_name = config_name
        summary['best_rora_config'] = best_config_name
        summary['best_rora_config_mean'] = best_config_acc
        all_results.append(summary)
        # Print fold summary
        print(f"\n Summary for {name}:")
        print(f" RoRA configs:")
        for config in RORA_CONFIGS:
            config_name = config['name']
            mean_key = f'rora_{config_name}_mean'
            std_key = f'rora_{config_name}_std'
            if mean_key in summary:
                # '*' marks the best individual config for this dataset.
                marker = " *" if config_name == best_config_name else ""
                print(f" {config_name}: {summary[mean_key]:.4f} ± {summary[std_key]:.4f}{marker}")
        print(f" RoRA-Ensemble: {summary['rora_ensemble_mean']:.4f} ± {summary['rora_ensemble_std']:.4f}")
        print(f" XGBoost: {summary['xgb_mean']:.4f} ± {summary['xgb_std']:.4f}")
        if summary['tabicl_mean'] is not None:
            print(f" TabICL: {summary['tabicl_mean']:.4f} ± {summary['tabicl_std']:.4f}")
        if summary['tabpfn_mean'] is not None:
            print(f" TabPFN: {summary['tabpfn_mean']:.4f} ± {summary['tabpfn_std']:.4f}")
        print(f" Winner: {summary['winner']}")
    return pd.DataFrame(all_results)
def print_final_summary(df: pd.DataFrame):
    """Print the final comparison tables: per-dataset results, win counts,
    mean accuracies, and a per-config RoRA breakdown.

    Expects one row per dataset as produced by run_benchmark.
    """
    print("\n\n" + "=" * 120)
    print("FINAL RESULTS SUMMARY")
    print("=" * 120)
    # Count wins per method across datasets.
    wins = df['winner'].value_counts()
    # Main comparison table (Ensemble vs baselines)
    print(f"\n{'Dataset':<25} {'RoRA-Ens':<16} {'Best-Cfg':<16} {'XGBoost':<16} {'TabICL':<16} {'TabPFN':<16} {'Winner':<14}")
    print("-" * 120)
    for _, row in df.iterrows():
        ens_str = f"{row['rora_ensemble_mean']:.4f}±{row['rora_ensemble_std']:.4f}"
        best_cfg_str = f"{row['best_rora_config_mean']:.4f}"
        xgb_str = f"{row['xgb_mean']:.4f}±{row['xgb_std']:.4f}"
        # TabICL/TabPFN columns may be missing entirely (method unavailable).
        if pd.notna(row.get('tabicl_mean')):
            icl_str = f"{row['tabicl_mean']:.4f}±{row['tabicl_std']:.4f}"
        else:
            icl_str = "N/A"
        if pd.notna(row.get('tabpfn_mean')):
            pfn_str = f"{row['tabpfn_mean']:.4f}±{row['tabpfn_std']:.4f}"
        else:
            pfn_str = "N/A"
        print(f"{row['dataset']:<25} {ens_str:<16} {best_cfg_str:<16} {xgb_str:<16} {icl_str:<16} {pfn_str:<16} {row['winner']:<14}")
    print("-" * 120)
    # Win counts per method (zero when a method never won).
    print(f"\n{'Method':<20} {'Wins':<10} {'Win Rate':<15}")
    print("-" * 45)
    total = len(df)
    for method in ['RoRA-Ensemble', 'XGBoost', 'TabICL', 'TabPFN']:
        count = wins.get(method, 0)
        rate = count / total * 100
        print(f"{method:<20} {count:<10} {rate:.1f}%")
    # Average performance across datasets (TabICL/TabPFN over available rows).
    print(f"\n{'Method':<20} {'Mean Accuracy':<15}")
    print("-" * 35)
    print(f"{'RoRA-Ensemble':<20} {df['rora_ensemble_mean'].mean():.4f}")
    print(f"{'XGBoost':<20} {df['xgb_mean'].mean():.4f}")
    if df['tabicl_mean'].notna().any():
        print(f"{'TabICL':<20} {df['tabicl_mean'].dropna().mean():.4f}")
    if df['tabpfn_mean'].notna().any():
        print(f"{'TabPFN':<20} {df['tabpfn_mean'].dropna().mean():.4f}")
    # Individual config breakdown
    print("\n\n" + "=" * 120)
    print("RORA CONFIG BREAKDOWN (Mean accuracy per config)")
    print("=" * 120)
    config_names = [c['name'] for c in RORA_CONFIGS]
    header = f"{'Dataset':<25}" + "".join([f"{name:<14}" for name in config_names]) + f"{'Ensemble':<14}"
    print(header)
    print("-" * 120)
    for _, row in df.iterrows():
        line = f"{row['dataset']:<25}"
        for config in RORA_CONFIGS:
            config_name = config['name']
            mean_key = f'rora_{config_name}_mean'
            if mean_key in row and pd.notna(row[mean_key]):
                line += f"{row[mean_key]:<14.4f}"
            else:
                line += f"{'N/A':<14}"
        line += f"{row['rora_ensemble_mean']:<14.4f}"
        print(line)
    print("-" * 120)
    # Average per config (column-wise means over all datasets).
    avg_line = f"{'AVERAGE':<25}"
    for config in RORA_CONFIGS:
        config_name = config['name']
        mean_key = f'rora_{config_name}_mean'
        if mean_key in df.columns:
            avg_line += f"{df[mean_key].mean():<14.4f}"
        else:
            avg_line += f"{'N/A':<14}"
    avg_line += f"{df['rora_ensemble_mean'].mean():<14.4f}"
    print(avg_line)
def main():
    """CLI entry point.

    Parses arguments, handles the cache-maintenance shortcuts, runs the
    benchmark, prints the summary tables, and saves the per-dataset CSV.

    Returns:
        The results DataFrame, or None when only cache clearing was requested.
    """
    import argparse
    parser = argparse.ArgumentParser(description='RoRA-Tab Benchmark')
    parser.add_argument('--folds', type=int, default=5, help='Number of CV folds')
    parser.add_argument('--seed', type=int, default=42, help='Random seed')
    parser.add_argument('--quick', action='store_true', help='Quick mode (skip TabICL/TabPFN)')
    parser.add_argument('--output', type=str, default='results/benchmark_cv_results.csv',
                        help='Output CSV path')
    parser.add_argument('--no-cache', action='store_true', help='Disable result caching')
    parser.add_argument('--recompute', type=str, nargs='+', default=[],
                        choices=['rora', 'xgb', 'tabicl', 'tabpfn', 'all'],
                        help='Force recompute specific methods even if cached')
    parser.add_argument('--clear-cache', action='store_true', help='Clear all cached results and exit')
    parser.add_argument('--clear-cache-dataset', type=str, default=None,
                        help='Clear cached results for a specific dataset and exit')
    args = parser.parse_args()
    # Cache-maintenance shortcuts: clear and exit without running anything.
    if args.clear_cache:
        clear_cache()
        return None
    if args.clear_cache_dataset:
        clear_cache(args.clear_cache_dataset)
        return None
    # Handle 'all' in recompute: expand to every cacheable method.
    force_recompute = args.recompute
    if 'all' in force_recompute:
        force_recompute = ['rora', 'xgb', 'tabicl', 'tabpfn']
    # Run benchmark
    df = run_benchmark(
        datasets=DATASETS,
        n_folds=args.folds,
        seed=args.seed,
        quick_mode=args.quick,
        use_cache=not args.no_cache,
        force_recompute=force_recompute
    )
    # Print summary
    print_final_summary(df)
    # Save results. NOTE: `os` is already imported at module level — the
    # original function-local `import os` was redundant and has been removed.
    out_dir = os.path.dirname(args.output)
    os.makedirs(out_dir if out_dir else '.', exist_ok=True)
    df.to_csv(args.output, index=False)
    print(f"\nResults saved to {args.output}")
    return df
# Script entry point; `results` keeps the DataFrame around for interactive use.
if __name__ == "__main__":
    results = main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment