@Proteusiq
Created December 19, 2025 19:42
fire torch
from functools import wraps
from pathlib import Path
from typing import Any, Literal, NamedTuple
import joblib
import numpy as np
import polars as pl
from hummingbird.ml import convert, load
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
type FrameType = Literal["pytorch", "onnx"]
type DeviceType = Literal["cpu", "cuda", "mps"]


class TrainedModel(NamedTuple):
    """Container for trained pipeline and label encoder.

    Attributes:
        pipeline: Fitted sklearn Pipeline with preprocessing and estimator.
        label_encoder: Fitted LabelEncoder for target variable.
    """

    pipeline: Pipeline
    label_encoder: LabelEncoder


class HybridPipeline:
    """Pipeline combining sklearn preprocessing with a PyTorch estimator.

    Hummingbird cannot convert all sklearn transformers (e.g., TableVectorizer).
    This class keeps incompatible preprocessing in sklearn while running the
    final estimator in PyTorch.

    Attributes:
        preprocessor: sklearn Pipeline for feature transformation.
        pytorch_estimator: Hummingbird-converted PyTorch model.
        device: Compute device for PyTorch inference.
        label_encoder: Optional encoder for decoding predictions.
    """

    def __init__(
        self,
        preprocessor: Pipeline,
        pytorch_estimator: Any,
        device: DeviceType,
        label_encoder: LabelEncoder | None = None,
    ) -> None:
        self.preprocessor = preprocessor
        self.pytorch_estimator = pytorch_estimator
        self.device = device
        self.label_encoder = label_encoder

    def predict(self, X: pl.DataFrame | np.ndarray) -> np.ndarray:
        """Transform features and predict using the PyTorch estimator.

        Args:
            X: Input features (DataFrame or array-like).

        Returns:
            Predictions, decoded to original labels if label_encoder exists.
        """
        X_transformed = self.preprocessor.transform(X)
        predictions = self.pytorch_estimator.predict(X_transformed)
        if self.label_encoder is not None:
            return self.label_encoder.inverse_transform(predictions)
        return predictions

    def predict_proba(self, X: pl.DataFrame | np.ndarray) -> np.ndarray:
        """Transform features and predict class probabilities.

        Args:
            X: Input features (DataFrame or array-like).

        Returns:
            Class probability matrix.
        """
        X_transformed = self.preprocessor.transform(X)
        return self.pytorch_estimator.predict_proba(X_transformed)

    @property
    def classes_(self) -> np.ndarray | None:
        """Original class labels from the label encoder."""
        if self.label_encoder is not None:
            return self.label_encoder.classes_
        return None

    def save(self, path: str) -> None:
        """Save pipeline components to disk.

        Args:
            path: Directory path for saving model artifacts.
        """
        base_path = Path(path)
        base_path.mkdir(parents=True, exist_ok=True)
        joblib.dump(self.preprocessor, base_path / "preprocessor.joblib")
        if self.label_encoder is not None:
            joblib.dump(self.label_encoder, base_path / "label_encoder.joblib")
        try:
            # Hummingbird writes the container as a zip archive,
            # loaded back below as "estimator.zip".
            self.pytorch_estimator.save(str(base_path / "estimator"))
        except Exception as e:
            raise RuntimeError(f"Failed to save PyTorch estimator: {e}") from e

    @classmethod
    def load(cls, path: str, device: DeviceType = "cpu") -> "HybridPipeline":
        """Load a saved HybridPipeline from disk.

        Args:
            path: Directory containing saved model artifacts.
            device: Target device for PyTorch estimator.

        Returns:
            Reconstructed HybridPipeline instance.
        """
        base_path = Path(path)
        preprocessor = joblib.load(base_path / "preprocessor.joblib")
        pytorch_estimator = load(str(base_path / "estimator.zip"), override_flag=True)
        pytorch_estimator.to(device)
        label_encoder_path = base_path / "label_encoder.joblib"
        label_encoder = (
            joblib.load(label_encoder_path) if label_encoder_path.exists() else None
        )
        return cls(preprocessor, pytorch_estimator, device, label_encoder)

    def __repr__(self) -> str:
        """String representation of the HybridPipeline."""
        n_steps = len(self.preprocessor.steps)
        encoder_info = "with label encoding" if self.label_encoder else "no label encoding"
        return f"HybridPipeline({n_steps} preprocessing steps, {encoder_info}, device={self.device})"


def torchit(
    framework: FrameType = "pytorch",
    device: DeviceType = "cpu",
):
    """Decorator to convert sklearn models to PyTorch via Hummingbird.

    For pipelines, only the final estimator is converted. Preprocessing steps
    remain in sklearn and are wrapped in a HybridPipeline.

    Args:
        framework: Target framework ("pytorch" or "onnx").
        device: Compute device ("cpu", "cuda", "mps").

    Returns:
        Decorator that wraps training functions.

    Example:
        @torchit(device="cuda")
        def train(X, y) -> TrainedModel:
            pipeline = make_pipeline(StandardScaler(), LogisticRegression())
            pipeline.fit(X, y_encoded)
            return TrainedModel(pipeline=pipeline, label_encoder=encoder)
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if isinstance(result, TrainedModel):
                sklearn_model = result.pipeline
                label_encoder = result.label_encoder
            else:
                sklearn_model = result
                label_encoder = None
            if isinstance(sklearn_model, Pipeline):
                # Split off the final estimator; earlier steps stay in sklearn.
                *preprocessing_steps, (_, estimator) = sklearn_model.steps
                pytorch_estimator = convert(estimator, framework)
                pytorch_estimator.to(device)
                return HybridPipeline(
                    Pipeline(preprocessing_steps),
                    pytorch_estimator,
                    device,
                    label_encoder,
                )
            # Bare estimator: convert the whole model.
            pytorch_model = convert(sklearn_model, framework)
            pytorch_model.to(device)
            return pytorch_model

        return wrapper

    return decorator
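
# A hedged sketch of torchit's non-pipeline branch: returning a bare estimator
# (no Pipeline, no TrainedModel) converts the whole model and hands back the
# Hummingbird container directly. Names below (train_bare, X_numeric, y_encoded)
# are illustrative only, not part of the gist:
#
#     @torchit(framework="pytorch", device="cpu")
#     def train_bare(X_numeric, y_encoded):
#         model = LogisticRegression(max_iter=1000)
#         model.fit(X_numeric, y_encoded)
#         return model
#
#     torch_model = train_bare(X_numeric, y_encoded)
#     predictions = torch_model.predict(X_numeric)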


if __name__ == "__main__":
    import polars as pl
    from sklearn.impute import SimpleImputer
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import classification_report
    from sklearn.model_selection import train_test_split
    from sklearn.pipeline import make_pipeline
    from sklearn.preprocessing import StandardScaler
    from skrub import TableVectorizer

    URL: str = (
        "https://raw.githubusercontent.com/mwaskom/seaborn-data/master/penguins.csv"
    )
    dataf = pl.read_csv(URL)

    TARGET = "species"
    X = dataf.select(pl.exclude(TARGET))
    y = dataf.get_column(TARGET).to_numpy()

    X_train, X_test, y_train, y_test = train_test_split(
        X,
        y,
        test_size=0.2,
        random_state=42,
        stratify=y,
    )
@torchit(device="mps")
def train_classifier(X_train, y_train) -> TrainedModel:
"""Train a penguin species classifier.
Args:
X_train: Training features.
y_train: Training labels (string species names).
Returns:
TrainedModel with fitted pipeline and label encoder.
"""
label_encoder = LabelEncoder()
y_encoded = label_encoder.fit_transform(y_train)
pipeline = make_pipeline(
TableVectorizer(),
SimpleImputer(add_indicator=True),
StandardScaler(),
LogisticRegression(max_iter=1000),
)
pipeline.fit(X_train, y_encoded)
return TrainedModel(pipeline=pipeline, label_encoder=label_encoder)

    classifier = train_classifier(X_train, y_train)
    print(f"Trained: {classifier}")
    classifier.save("penguins_hb")

    print("Classes:", classifier.classes_)
    print(classification_report(y_test, classifier.predict(X_test)))

    classifier_reloaded = HybridPipeline.load("penguins_hb", device="mps")
    print("\nReloaded from disk:")
    print(f"Loaded: {classifier_reloaded}")
    print("Classes:", classifier_reloaded.classes_)
    print(classification_report(y_test, classifier_reloaded.predict(X_test)))
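
    # A minimal sketch of probability output from the reloaded hybrid model;
    # assumes classes_ and predict_proba behave as defined above.
    probabilities = classifier_reloaded.predict_proba(X_test)
    print("Probability matrix shape:", probabilities.shape)
    for label, proba in zip(classifier_reloaded.classes_, probabilities[0]):
        print(f"  P({label}) = {proba:.3f}")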