@betatim
Created February 9, 2026 07:29
Agent support documents for array API work in scikit-learn

Array API Architecture

Created: 2026-01-07 Last Updated: 2026-01-07

Overview

Scikit-learn's Array API support enables estimators and functions to work with arrays from different libraries (NumPy, CuPy, PyTorch) without modification. This allows computations to run on GPUs when using GPU-backed array libraries.

The implementation follows the Array API Standard, a specification that defines a common API for array manipulation libraries.
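A minimal sketch of what this looks like in practice (it assumes SCIPY_ARRAY_API=1 was exported before starting Python, as described below; PCA with svd_solver="full" is one of the array API compatible estimators):

import numpy as np
from sklearn import config_context
from sklearn.decomposition import PCA

X_np = np.random.randn(100, 10)

with config_context(array_api_dispatch=True):
    pca = PCA(n_components=2, svd_solver="full").fit(X_np)
    # With PyTorch installed, the identical call accepts GPU tensors:
    #   import torch
    #   X_gpu = torch.asarray(X_np, device="cuda")
    #   pca = PCA(n_components=2, svd_solver="full").fit(X_gpu)
    #   pca.components_ is then a CUDA tensor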

Architecture Diagram

flowchart TB
    subgraph user [User Code]
        UserInput["X, y (torch/cupy/numpy)"]
    end

    subgraph config [Configuration]
        ConfigDispatch["array_api_dispatch=True"]
        SciPyEnv["SCIPY_ARRAY_API=1"]
    end

    subgraph sklearn [scikit-learn]
        subgraph utils [sklearn.utils._array_api]
            GetNS["get_namespace()"]
            GetNSD["get_namespace_and_device()"]
            MoveTo["move_to()"]
            Helpers["Helper functions"]
        end

        subgraph estimator [Estimator/Function]
            Fit["fit() / transform() / predict()"]
            Tags["__sklearn_tags__()"]
        end

        subgraph externals [sklearn.externals]
            AAC["array_api_compat"]
            AAX["array_api_extra"]
        end
    end

    subgraph backends [Array Backends]
        NumPy["numpy"]
        CuPy["cupy"]
        Torch["torch"]
        Strict["array-api-strict"]
    end

    UserInput --> Fit
    ConfigDispatch --> GetNS
    SciPyEnv --> GetNS
    Fit --> GetNS
    GetNS --> AAC
    AAC --> backends
    Helpers --> AAX
    MoveTo --> backends

Configuration Requirements

1. SciPy Array API Environment Variable

Before importing scipy or scikit-learn:

export SCIPY_ARRAY_API=1

This is required because scikit-learn uses scipy for some operations, and scipy needs this flag to enable its own array API support.

2. Enable Array API Dispatch

import sklearn
sklearn.set_config(array_api_dispatch=True)

# Or use a context manager:
from sklearn import config_context
with config_context(array_api_dispatch=True):
    model.fit(X_gpu, y_gpu)

Supported Libraries

  • NumPy: CPU. Default, always works.
  • CuPy: CUDA GPU. Requires a CUDA-capable GPU.
  • PyTorch: CPU, CUDA, MPS, XPU. Most flexible device support.
  • array-api-strict: CPU (simulated devices). For testing/development only.

Device Considerations

  • PyTorch MPS (Apple Silicon GPU): Limited to float32, some operations fall back to CPU
  • float64 support: Not all devices support float64 (e.g., MPS). Scikit-learn falls back to float32 when needed.
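A short sketch of handling this in code, using the supported_float_dtypes() helper documented later in this reference (X is assumed to be an array from one of the supported libraries):

from sklearn.utils._array_api import get_namespace_and_device, supported_float_dtypes

xp, _, device_ = get_namespace_and_device(X)
# On MPS this is (float32,); on CPU/CUDA it starts with float64
float_dtypes = supported_float_dtypes(xp, device=device_)
X = xp.astype(X, float_dtypes[0])  # highest precision the device supports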

Component Relationships

1. array-api-compat (sklearn/externals/array_api_compat/)

A vendored copy of array-api-compat that provides:

  • Namespace detection via get_namespace(*arrays)
  • Compatibility wrappers that add missing array API functions to libraries
  • Unified interface across numpy, cupy, torch

Key functions used:

  • array_api_compat.get_namespace() - Detect the array library
  • array_api_compat.is_torch_namespace() - Check if namespace is PyTorch
  • array_api_compat.numpy - NumPy wrapped with array API compatibility
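A minimal sketch of using the vendored compat layer directly (scikit-learn code normally goes through sklearn.utils._array_api instead):

import numpy as np
from sklearn.externals import array_api_compat

x = np.asarray([1.0, 2.0, 3.0])
xp = array_api_compat.get_namespace(x)  # the numpy-compat namespace
print(xp.mean(x))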

2. array-api-extra (sklearn/externals/array_api_extra/)

A vendored copy of array-api-extra that provides:

  • Additional array functions not in the standard
  • xpx.at[] for advanced indexing
  • xpx.cov(), xpx.kron(), etc.

Imported as:

from sklearn.externals import array_api_extra as xpx
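For example, xpx.at provides out-of-place item assignment that also works for backends with immutable arrays (a sketch; x is an assumed existing array and the call form is xpx.at(x)[idx].set(value)):

from sklearn.externals import array_api_extra as xpx

# Out-of-place equivalent of x[0] = 42.0; returns the updated array
x = xpx.at(x)[0].set(42.0)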

3. sklearn.utils._array_api

Scikit-learn's internal utilities that build on the vendored libraries:

  • get_namespace() - Wraps array-api-compat, checks config
  • get_namespace_and_device() - Also returns device info
  • move_to() - Transfer arrays between namespaces/devices
  • Helper functions for operations not in array API spec

Namespace Flow

sequenceDiagram
    participant User
    participant Estimator
    participant get_namespace
    participant array_api_compat
    participant Backend

    User->>Estimator: fit(X_gpu, y)
    Estimator->>get_namespace: get_namespace(X_gpu)
    get_namespace->>get_namespace: Check array_api_dispatch config
    get_namespace->>array_api_compat: get_namespace(X_gpu)
    array_api_compat->>array_api_compat: Introspect array type
    array_api_compat-->>get_namespace: xp (torch namespace)
    get_namespace-->>Estimator: (xp, is_array_api=True)
    Estimator->>Backend: xp.mean(X_gpu), xp.linalg.svd(X_gpu)
    Backend-->>Estimator: Results on GPU
    Estimator-->>User: Fitted model with GPU attributes

Input Conversion Rules

Estimators: "Everything Follows X"

When fitting an estimator with array API inputs:

  1. The namespace and device are determined from X
  2. All other inputs (y, sample_weight, constructor arrays) are converted to match X
  3. Fitted attributes are stored in the same namespace/device as X
  4. predict()/transform() expect inputs from the same namespace/device
# Example: X on GPU, y on CPU -> y converted to GPU
with config_context(array_api_dispatch=True):
    lda.fit(X_cuda, y_numpy)  # y is moved to CUDA
    # lda.coef_ is now a CUDA tensor

Metrics: "Everything Follows y_pred"

When calling scoring functions:

  1. The namespace and device are determined from y_pred
  2. All other inputs (y_true, sample_weight) are converted to match y_pred
  3. Scalar outputs return Python floats
  4. Multi-value outputs return arrays in the same namespace
# Example: y_pred on GPU, y_true on CPU -> y_true converted to GPU
with config_context(array_api_dispatch=True):
    score = accuracy_score(y_true_numpy, y_pred_cuda)
    # Returns Python float, not tensor

When Array API Dispatch is Disabled

When array_api_dispatch=False (the default):

  1. get_namespace() always returns the numpy-compat wrapper
  2. is_array_api_compliant is always False
  3. Scikit-learn expects NumPy arrays and calls np.asarray() on inputs
  4. Some array types (like GPU tensors) may fail or trigger implicit copies
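A minimal sketch of the disabled (default) behavior described above:

import numpy as np
from sklearn.utils._array_api import get_namespace

# array_api_dispatch defaults to False
xp, is_array_api_compliant = get_namespace(np.asarray([1.0, 2.0]))
# xp is the numpy-compat wrapper; is_array_api_compliant is False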

Experimental Status

Array API support is currently experimental. This means:

  • API may change between releases
  • Not all estimators support it
  • Some edge cases may not be fully handled
  • Performance may not always be optimal

Track progress at: scikit-learn/scikit-learn#22352

Array API Support for Estimators

Created: 2026-01-07 Last Updated: 2026-01-07

Overview

This document describes the patterns and principles for adding Array API support to scikit-learn estimators. Follow these guidelines when converting an estimator to support GPU arrays.

Prerequisites for an Estimator

An estimator can be converted to support array API if:

  1. It primarily uses NumPy operations (not Cython)
  2. It doesn't rely on scipy functions that lack array API support
  3. The algorithm can be expressed using array API standard operations

Step-by-Step Implementation Pattern

Step 1: Add Required Imports

from sklearn.utils._array_api import (
    get_namespace,
    get_namespace_and_device,
    device,
    supported_float_dtypes,
    # Add any helper functions you need:
    _average,
    _nanmean,
    _convert_to_numpy,
)

Step 2: Declare Array API Support in Tags

Override __sklearn_tags__() to declare support:

def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    tags.array_api_support = True
    return tags

For conditional support (based on parameters):

def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    # Only SVD solver supports array API
    tags.array_api_support = self.solver == "svd"
    return tags

Step 3: Get Namespace at Method Entry

At the start of fit(), transform(), predict(), and similar methods:

def fit(self, X, y):
    xp, _ = get_namespace(X)
    # or if you need device info:
    xp, is_array_api, device_ = get_namespace_and_device(X)

Step 4: Use Namespace-Aware Validation

Pass the namespace's dtypes to validation:

X, y = validate_data(
    self, X, y,
    ensure_min_samples=2,
    dtype=[xp.float64, xp.float32]
)

Or use supported_float_dtypes() for device-aware dtype selection:

xp, _, X_device = get_namespace_and_device(X)
X = validate_data(
    self, X,
    reset=False,
    dtype=supported_float_dtypes(xp, X_device),
)

Step 5: Replace NumPy Operations with Namespace Operations

  • np.mean(X) → xp.mean(X)
  • np.std(X) → xp.std(X)
  • np.zeros((n, m)) → xp.zeros((n, m), device=device(X))
  • np.ones_like(X) → xp.ones_like(X)
  • np.concatenate([a, b]) → xp.concat([a, b])
  • np.dot(a, b) → a @ b or xp.matmul(a, b)
  • np.sum(X, axis=0) → xp.sum(X, axis=0)
  • np.unique(y) → xp.unique_values(y)
  • np.argsort(X) → xp.argsort(X, stable=True)
  • scipy.linalg.svd(X) → xp.linalg.svd(X)
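As an illustration, here is a small NumPy-only helper rewritten using the substitutions above (the function itself is a made-up example, not scikit-learn code):

import numpy as np
from sklearn.utils._array_api import get_namespace

def standardize_np(X):
    return (X - np.mean(X, axis=0)) / np.std(X, axis=0)

def standardize_xp(X):
    xp, _ = get_namespace(X)
    return (X - xp.mean(X, axis=0)) / xp.std(X, axis=0)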

Step 6: Handle Array Creation with Device

Always specify device when creating new arrays:

# Wrong - may create on wrong device
zeros = xp.zeros((n, m))

# Correct - explicitly set device
zeros = xp.zeros((n, m), device=device(X), dtype=X.dtype)

# Also correct - infer from existing array
zeros = xp.zeros_like(X)

Step 7: Handle Conditional NumPy/SciPy Usage

Some operations require different implementations:

def _solve_svd(self, X, y):
    xp, is_array_api_compliant = get_namespace(X)

    if is_array_api_compliant:
        # Use array API linalg
        svd = xp.linalg.svd
    else:
        # Fall back to scipy for NumPy arrays
        svd = scipy.linalg.svd

    _, S, Vt = svd(X, full_matrices=False)

Complete Example: LinearDiscriminantAnalysis

From sklearn/discriminant_analysis.py:

def _solve_svd(self, X, y):
    """SVD solver with array API support."""
    xp, is_array_api_compliant = get_namespace(X)

    if is_array_api_compliant:
        svd = xp.linalg.svd
    else:
        svd = scipy.linalg.svd

    n_samples, _ = X.shape
    n_classes = self.classes_.shape[0]

    self.means_ = _class_means(X, y)  # Also array API aware

    # Centering
    Xc = []
    for idx, group in enumerate(self.classes_):
        Xg = X[y == group]
        Xc.append(Xg - self.means_[idx, :])

    self.xbar_ = self.priors_ @ self.means_
    Xc = xp.concat(Xc, axis=0)

    # Scaling
    std = xp.std(Xc, axis=0)
    std[std == 0] = 1.0
    fac = xp.asarray(1.0 / (n_samples - n_classes), dtype=X.dtype, device=device(X))

    X = xp.sqrt(fac) * (Xc / std)
    _, S, Vt = svd(X, full_matrices=False)

    rank = xp.sum(xp.astype(S > self.tol, xp.int32))
    # ... rest of implementation

Complete Example: StandardScaler

From sklearn/preprocessing/_data.py:

def transform(self, X, copy=None):
    """Perform standardization by centering and scaling."""
    xp, _, X_device = get_namespace_and_device(X)
    check_is_fitted(self)

    copy = copy if copy is not None else self.copy
    X = validate_data(
        self, X,
        reset=False,
        accept_sparse="csr",
        copy=copy,
        dtype=supported_float_dtypes(xp, X_device),
        force_writeable=True,
        ensure_all_finite="allow-nan",
    )

    if sparse.issparse(X):
        # Sparse path (NumPy only)
        if self.with_mean:
            raise ValueError("Cannot center sparse matrices")
        if self.scale_ is not None:
            inplace_column_scale(X, 1 / self.scale_)
    else:
        # Dense path (array API compatible)
        if self.with_mean:
            X -= xp.astype(self.mean_, X.dtype)
        if self.with_std:
            X /= xp.astype(self.scale_, X.dtype)

    return X

def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    tags.input_tags.allow_nan = True
    tags.input_tags.sparse = not self.with_mean
    tags.transformer_tags.preserves_dtype = ["float64", "float32"]
    tags.array_api_support = True
    return tags

Fitted Attributes

Fitted attributes that are arrays should:

  1. Be stored in the same namespace as the input X
  2. Be on the same device as the input X
  3. Be accessible after fitting without conversion
def fit(self, X, y):
    xp, _ = get_namespace(X)

    # Compute and store in same namespace
    self.mean_ = xp.mean(X, axis=0)  # Will be torch tensor if X is
    self.classes_ = xp.unique_values(y)

    return self

Converting Fitted Estimators

To move a fitted estimator's attributes to a different format:

from sklearn.utils._array_api import _estimator_with_converted_arrays

# Convert CUDA tensors to NumPy
cupy_to_ndarray = lambda array: array.get()
lda_numpy = _estimator_with_converted_arrays(lda_gpu, cupy_to_ndarray)

Common Pitfalls

1. Forgetting Device on Array Creation

# Wrong
mask = xp.zeros(n, dtype=xp.bool)

# Correct
mask = xp.zeros(n, dtype=xp.bool, device=device(X))

2. Using NumPy Functions Directly

# Wrong
result = np.sum(X)

# Correct
xp, _ = get_namespace(X)
result = xp.sum(X)

3. Implicit Type Coercion

# Wrong - Python float may not broadcast correctly
X = X * 0.5

# Correct - explicit array creation
X = X * xp.asarray(0.5, dtype=X.dtype, device=device(X))

4. Boolean Indexing

Array API has limited boolean indexing support:

# This may not work in all namespaces
X_subset = X[mask]

# Use xp.take with integer indices instead
indices = xp.arange(X.shape[0], device=device(X))
X_subset = xp.take(X, indices[mask], axis=0)

5. Calling Methods That Don't Exist

Not all NumPy methods exist in array API:

# These don't exist in array API:
# - np.nanmean, np.nanmin, np.nanmax
# - np.bincount
# - np.unique (use unique_values, unique_inverse, etc.)

# Use sklearn helpers instead:
from sklearn.utils._array_api import _nanmean, _bincount

Operations Not in Array API Standard

For operations not in the standard, use sklearn's helper functions:

  • np.nanmean → _nanmean(X, axis=None, xp=None)
  • np.nanmin → _nanmin(X, axis=None, xp=None)
  • np.nanmax → _nanmax(X, axis=None, xp=None)
  • np.bincount → _bincount(array, weights=None, minlength=None, xp=None)
  • np.average → _average(a, axis=None, weights=None, xp=None)
  • np.median → _median(x, axis=None, keepdims=False, xp=None)
  • np.isin → _isin(element, test_elements, xp, ...)
  • scipy.special.logsumexp → _logsumexp(array, axis=None, xp=None)
  • np.fill_diagonal → _fill_diagonal(array, value, xp)
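The helpers mirror the NumPy signatures and accept an optional xp argument. A sketch (X, y_int, and n_classes are assumed inputs):

from sklearn.utils._array_api import get_namespace, _nanmean, _bincount

xp, _ = get_namespace(X)
column_means = _nanmean(X, axis=0, xp=xp)                    # NaN-aware column means
class_counts = _bincount(y_int, minlength=n_classes, xp=xp)  # per-class counts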

Estimators With Partial Support

Some estimators only support array API for certain parameter combinations:

# PCA: array API only with certain solvers
def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    solver = getattr(self, "_fit_svd_solver", self.svd_solver)
    tags.array_api_support = solver not in ["arpack", "randomized"] or (
        solver == "randomized" and self.power_iteration_normalizer == "QR"
    )
    return tags

Document these restrictions in the docstring and raise NotImplementedError for unsupported configurations:

def fit(self, X, y):
    xp, is_array_api = get_namespace(X)

    if is_array_api and self.init_params not in ["random", "random_from_data"]:
        raise NotImplementedError(
            f"init_params={self.init_params!r} is not supported with "
            "array_api_dispatch enabled."
        )

Array API Support for Metrics and Functions

Created: 2026-01-07 Last Updated: 2026-01-07

Overview

This document describes the patterns for adding Array API support to scikit-learn metrics and standalone functions. Metrics follow a different convention from estimators: the namespace and device follow y_pred instead of X.

The "y_pred Leads" Convention

For scoring functions:

  1. The namespace and device are determined from y_pred
  2. All other inputs (y_true, sample_weight, etc.) are converted to match y_pred
  3. This enables metrics to work within pipelines where X moves between devices

Why y_pred Leads

Consider a pipeline where:

  1. X starts on CPU (with string categorical data)
  2. A transformer encodes and moves X to GPU
  3. A GPU-based estimator makes predictions

The predictions (y_pred) will be on GPU, but y_true may still be on CPU. By following y_pred, the metric automatically handles this mismatch.
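A sketch of that situation (clf, X_gpu, and y_true_np are assumed to already exist):

from sklearn import config_context
from sklearn.metrics import accuracy_score

with config_context(array_api_dispatch=True):
    y_pred = clf.predict(X_gpu)                # predictions live on the GPU
    score = accuracy_score(y_true_np, y_pred)  # y_true_np is moved to match y_pred
    # score is a plain Python float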

Implementation Pattern for Metrics

Step 1: Get Namespace from y_pred

from sklearn.utils._array_api import (
    get_namespace_and_device,
    move_to,
    _average,
    _is_numpy_namespace,
)

def my_metric(y_true, y_pred, *, sample_weight=None):
    # Determine namespace and device from y_pred
    xp, _, device_ = get_namespace_and_device(y_pred)

Step 2: Convert Other Inputs with move_to()

def my_metric(y_true, y_pred, *, sample_weight=None):
    xp, _, device_ = get_namespace_and_device(y_pred)

    # Convert y_true and sample_weight to match y_pred
    y_true, sample_weight = move_to(y_true, sample_weight, xp=xp, device=device_)

Step 3: Perform Computation with Namespace

def my_metric(y_true, y_pred, *, sample_weight=None):
    xp, _, device_ = get_namespace_and_device(y_pred)
    y_true, sample_weight = move_to(y_true, sample_weight, xp=xp, device=device_)

    # Use namespace operations
    correct = y_true == y_pred
    score = xp.sum(correct)

    return float(score)  # Return Python float for scalar

Return Type Conventions

Scalar Outputs: Return Python Float

def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
    # ... computation ...
    return float(_average(score, weights=sample_weight, normalize=normalize, xp=xp))

Multi-Value Outputs: Return Array in Same Namespace

When a function needs to return multiple values (e.g., per-class scores, a matrix), return a single array in the same namespace as the input:

def confusion_matrix(y_true, y_pred, ...):
    # ... computation ...
    # Returns a 2D array (n_classes, n_classes) in the input namespace
    return xp.asarray(cm, device=device_)

def per_class_scores(y_true, y_pred, ...):
    # ... computation ...
    # Returns a 1D array with one value per class, in the input namespace
    return scores  # shape (n_classes,), in xp namespace

The array contains all the values; callers index into it as needed.

Complete Example: accuracy_score

From sklearn/metrics/_classification.py:

@validate_params(
    {
        "y_true": ["array-like", "sparse matrix"],
        "y_pred": ["array-like", "sparse matrix"],
        "normalize": ["boolean"],
        "sample_weight": ["array-like", None],
    },
    prefer_skip_nested_validation=True,
)
def accuracy_score(y_true, y_pred, *, normalize=True, sample_weight=None):
    """Accuracy classification score."""
    # Step 1: Get namespace from y_pred
    xp, _, device = get_namespace_and_device(y_pred)

    # Step 2: Convert y_true and sample_weight to match
    y_true, sample_weight = move_to(y_true, sample_weight, xp=xp, device=device)

    # Step 3: Validate and prepare
    y_true, y_pred = attach_unique(y_true, y_pred)
    y_type, y_true, y_pred, sample_weight = _check_targets(
        y_true, y_pred, sample_weight
    )

    # Step 4: Compute with namespace operations
    if y_type.startswith("multilabel"):
        differing_labels = _count_nonzero(y_true - y_pred, xp=xp, device=device, axis=1)
        score = xp.asarray(differing_labels == 0, device=device)
    else:
        score = y_true == y_pred

    # Step 5: Return Python float
    return float(_average(score, weights=sample_weight, normalize=normalize, xp=xp))

Complete Example: confusion_matrix

A more complex example showing fallback to NumPy for efficiency:

def confusion_matrix(y_true, y_pred, *, labels=None, sample_weight=None, normalize=None):
    """Compute confusion matrix."""
    # Get namespace for output consistency
    xp, _, device_ = get_namespace_and_device(y_true, y_pred, labels, sample_weight)

    # Validate inputs
    y_true = check_array(y_true, dtype=None, ensure_2d=False, ...)
    y_pred = check_array(y_pred, dtype=None, ensure_2d=False, ...)

    # Convert to NumPy for efficient computation
    # (scipy.sparse.coo_matrix is more efficient for this)
    y_true = _convert_to_numpy(y_true, xp)
    y_pred = _convert_to_numpy(y_pred, xp)
    if sample_weight is not None:
        sample_weight = _convert_to_numpy(sample_weight, xp)

    # ... compute confusion matrix using NumPy/scipy ...

    # Convert result back to input namespace for consistency
    return xp.asarray(cm, device=device_)

Complete Example: balanced_accuracy_score

Shows handling namespace-specific behavior:

def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=False):
    """Compute the balanced accuracy."""
    # Compute confusion matrix (already handles namespace)
    C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)

    xp, _, device_ = get_namespace_and_device(y_pred, y_true)

    # Handle array-api-strict quirks
    if _is_xp_namespace(xp, "array_api_strict"):
        # array_api_strict only supports floating point for __truediv__
        C = xp.astype(C, _max_precision_float_dtype(xp, device=device_), copy=False)

    # Handle division warnings for NumPy
    context_manager = (
        np.errstate(divide="ignore", invalid="ignore")
        if _is_numpy_namespace(xp)
        else nullcontext()
    )

    with context_manager:
        per_class = xp.linalg.diagonal(C) / xp.sum(C, axis=1)

    # Handle NaN for missing classes
    if xp.any(xp.isnan(per_class)):
        warnings.warn("y_pred contains classes not in y_true")
        per_class = per_class[~xp.isnan(per_class)]

    score = xp.mean(per_class)

    if adjusted:
        n_classes = per_class.shape[0]
        chance = 1 / n_classes
        score -= chance
        score /= 1 - chance

    return float(score)

Handling Helper Functions

When a metric calls internal helper functions, pass the namespace:

def _check_targets(y_true, y_pred, sample_weight=None):
    """Check targets are valid and consistent."""
    xp, _ = get_namespace(y_true, y_pred)
    # ... use xp throughout
    return y_type, y_true, y_pred, sample_weight


def my_metric(y_true, y_pred, *, sample_weight=None):
    xp, _, device_ = get_namespace_and_device(y_pred)
    y_true, sample_weight = move_to(y_true, sample_weight, xp=xp, device=device_)

    # Helper function will get namespace from its inputs
    y_type, y_true, y_pred, sample_weight = _check_targets(
        y_true, y_pred, sample_weight
    )

Pairwise Metrics

For pairwise distance/kernel functions:

def euclidean_distances(X, Y=None, *, Y_norm_squared=None, squared=False, X_norm_squared=None):
    xp, _, device_ = get_namespace_and_device(X, Y)

    # Move Y to match X if needed
    if Y is not None:
        Y = move_to(Y, xp=xp, device=device_)

    # Compute with namespace operations
    # ...

    return distances  # Returns array in same namespace as X

Common Patterns

Pattern 1: Simple Binary Metric

def binary_metric(y_true, y_pred):
    xp, _, device_ = get_namespace_and_device(y_pred)
    y_true = move_to(y_true, xp=xp, device=device_)

    result = xp.sum(y_true == y_pred) / y_true.shape[0]
    return float(result)

Pattern 2: Metric with sample_weight

def weighted_metric(y_true, y_pred, *, sample_weight=None):
    xp, _, device_ = get_namespace_and_device(y_pred)
    y_true, sample_weight = move_to(y_true, sample_weight, xp=xp, device=device_)

    # Use _average helper for weighted averaging
    return float(_average(y_true == y_pred, weights=sample_weight, xp=xp))

Pattern 3: Metric Returning Multiple Values

def multi_value_metric(y_true, y_pred):
    xp, _, device_ = get_namespace_and_device(y_pred)
    y_true = move_to(y_true, xp=xp, device=device_)

    value1 = xp.mean(y_true == y_pred)
    value2 = xp.sum(y_true != y_pred)

    # Return arrays, not floats, for multiple values
    return value1, value2

Pattern 4: Metric with Fallback to NumPy

For efficiency, some metrics compute in NumPy then convert back:

def efficient_metric(y_true, y_pred):
    xp, _, device_ = get_namespace_and_device(y_pred)

    # Convert to NumPy for efficient computation
    y_true_np = _convert_to_numpy(y_true, xp)
    y_pred_np = _convert_to_numpy(y_pred, xp)

    # Use efficient NumPy/SciPy operations
    result_np = scipy_function(y_true_np, y_pred_np)

    # Convert back to original namespace
    return xp.asarray(result_np, device=device_)

Testing Metrics with Array API

See the testing patterns document for details. Key points:

  • Use yield_namespace_device_dtype_combinations() for parametrization
  • Compare results between NumPy and array API implementations
  • Handle floating-point tolerance differences for different dtypes
@pytest.mark.parametrize(
    "array_namespace, device, _",
    yield_namespace_device_dtype_combinations()
)
def test_my_metric_array_api(array_namespace, device, _):
    xp = _array_api_for_tests(array_namespace, device)

    y_true_np = np.array([0, 1, 1, 0])
    y_pred_np = np.array([0, 1, 0, 0])

    y_pred_xp = xp.asarray(y_pred_np, device=device)

    result_np = my_metric(y_true_np, y_pred_np)

    with config_context(array_api_dispatch=True):
        result_xp = my_metric(y_true_np, y_pred_xp)  # y_true stays numpy

    assert result_xp == pytest.approx(result_np)

Array API Testing Patterns

Created: 2026-01-07 Last Updated: 2026-01-26

Overview

This document describes how to write tests for array API support in scikit-learn. Testing ensures that estimators and functions work correctly across different array libraries and devices.

Test Infrastructure

Key Testing Utilities

from sklearn.utils._array_api import (
    yield_namespace_device_dtype_combinations,
    _get_namespace_device_dtype_ids,
    _convert_to_numpy,
    get_namespace,
    device as array_device,  # aliased: test parameters below are named "device"
)
from sklearn.utils._testing import (
    _array_api_for_tests,
    skip_if_array_api_compat_not_configured,
)
from sklearn.utils.estimator_checks import (
    check_array_api_input,
    check_array_api_input_and_values,
)

Getting a Test Namespace

Use _array_api_for_tests() to get a namespace for testing:

from sklearn.utils._testing import _array_api_for_tests

xp = _array_api_for_tests("torch", "cuda")
X = xp.asarray(X_np, device="cuda")

This function:

  • Skips the test if the library isn't installed
  • Skips if the device isn't available (e.g., no GPU)
  • Returns the namespace module ready for use

Common Estimator Tests

Using check_array_api_input

The check_array_api_input function is the standard way to test estimators:

from sklearn.utils.estimator_checks import check_array_api_input

@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
def test_my_estimator_array_api(array_namespace, device, dtype_name):
    estimator = MyEstimator()
    check_array_api_input(
        name="MyEstimator",
        estimator_orig=estimator,
        array_namespace=array_namespace,
        device=device,
        dtype_name=dtype_name,
    )

What check_array_api_input verifies:

  1. Estimator can be fitted with array API inputs
  2. Fitted attributes are in the same namespace as input
  3. Fitted attributes are on the same device as input
  4. All methods (predict, transform, etc.) work with array API inputs
  5. Output types are consistent

Using check_array_api_input_and_values

For stricter testing that also checks numerical results:

from sklearn.utils.estimator_checks import check_array_api_input_and_values

@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
@pytest.mark.parametrize(
    "check",
    [check_array_api_input_and_values],
    ids=_get_check_estimator_ids,
)
def test_pca_array_api(check, array_namespace, device, dtype_name):
    estimator = PCA(n_components=2, svd_solver="full")
    check(
        "PCA",
        estimator,
        array_namespace,
        device=device,
        dtype_name=dtype_name,
    )

Additional checks with check_values=True:

  • Numerical results match between NumPy and array API
  • Uses appropriate tolerance based on dtype

Testing with sample_weight

check_array_api_input(
    name="StandardScaler",
    estimator_orig=StandardScaler(),
    array_namespace=array_namespace,
    device=device,
    dtype_name=dtype_name,
    check_sample_weight=True,  # Test sample_weight handling
)

Writing Custom Array API Tests

Basic Pattern

@skip_if_array_api_compat_not_configured
@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
def test_my_function_array_api(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)

    # Create test data
    X_np = np.array([[1, 2], [3, 4]], dtype=dtype_name)
    y_np = np.array([0, 1])

    # Convert to array API
    X_xp = xp.asarray(X_np, device=device)
    y_xp = xp.asarray(y_np, device=device)

    # Call with NumPy
    result_np = my_function(X_np, y_np)

    # Call with array API
    with config_context(array_api_dispatch=True):
        result_xp = my_function(X_xp, y_xp)

    # Compare results
    result_xp_np = _convert_to_numpy(result_xp, xp=xp)
    assert_allclose(result_np, result_xp_np, rtol=1e-5)

Testing Namespace Preservation

def test_namespace_preserved(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)

    X = xp.asarray(np.random.randn(100, 10).astype(dtype_name), device=device)

    with config_context(array_api_dispatch=True):
        scaler = StandardScaler()
        scaler.fit(X)

        # Check fitted attributes are in correct namespace
        assert get_namespace(scaler.mean_)[0] == xp
        assert get_namespace(scaler.scale_)[0] == xp

        # Check device; use the aliased utility, since the parametrized
        # string argument is also named "device" and would shadow it
        assert array_device(scaler.mean_) == array_device(X)

Testing Mixed Inputs (move_to behavior)

def test_mixed_inputs(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)

    # X on GPU, y on CPU
    X_xp = xp.asarray(np.random.randn(100, 10).astype(dtype_name), device=device)
    y_np = np.array([0, 1] * 50)

    with config_context(array_api_dispatch=True):
        clf = LinearDiscriminantAnalysis()
        clf.fit(X_xp, y_np)  # y should be converted to match X

        # Predictions should be on same device as X
        predictions = clf.predict(X_xp)
        assert array_device(predictions) == array_device(X_xp)

Testing Metrics

Pattern for Metric Tests

@pytest.mark.parametrize(
    "array_namespace, device, _",
    yield_namespace_device_dtype_combinations()
)
def test_accuracy_score_array_api(array_namespace, device, _):
    xp = _array_api_for_tests(array_namespace, device)

    y_true_np = np.array([0, 1, 1, 0, 1])
    y_pred_np = np.array([0, 1, 0, 0, 1])

    # y_pred on device, y_true stays on CPU
    y_pred_xp = xp.asarray(y_pred_np, device=device)

    result_np = accuracy_score(y_true_np, y_pred_np)

    with config_context(array_api_dispatch=True):
        result_xp = accuracy_score(y_true_np, y_pred_xp)

    # Metrics return Python floats
    assert isinstance(result_xp, float)
    assert result_xp == pytest.approx(result_np)

Testing with sample_weight

def test_metric_with_sample_weight(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)

    y_true = np.array([0, 1, 1, 0])
    y_pred = np.array([0, 1, 0, 0])
    sample_weight = np.array([1.0, 2.0, 1.0, 1.0], dtype=dtype_name)

    y_pred_xp = xp.asarray(y_pred, device=device)
    sample_weight_xp = xp.asarray(sample_weight, device=device)

    result_np = accuracy_score(y_true, y_pred, sample_weight=sample_weight)

    with config_context(array_api_dispatch=True):
        result_xp = accuracy_score(y_true, y_pred_xp, sample_weight=sample_weight_xp)

    assert result_xp == pytest.approx(result_np)

Handling Tolerances

Different dtypes require different tolerances. Use _atol_for_type() for consistent, dtype-appropriate absolute tolerances across the codebase.

The _atol_for_type() Utility (Preferred)

from sklearn.utils._array_api import _atol_for_type

def test_numerical_accuracy(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)

    # Use _atol_for_type for absolute tolerance
    # Formula: numpy.finfo(dtype).eps * 1000
    # - float64: ~2.22e-13
    # - float32: ~1.19e-4
    atol = _atol_for_type(dtype_name)

    # Common rtol pattern
    rtol = 1e-5 if dtype_name == "float64" else 1e-4

    assert_allclose(result_np, result_xp_np, rtol=rtol, atol=atol)

Why atol Matters for Comparisons Against Zero

When using assert_allclose(actual, desired, rtol=...), the check is:

|actual - desired| <= atol + rtol * |desired|

Problem: When desired=0 and default atol=0:

  • The check becomes |actual| <= 0
  • Any tiny floating-point noise (e.g., 1e-15) causes failure
  • The relative difference becomes inf (division by zero)

Solution: Always use atol=_atol_for_type(dtype_name) when comparing values that may include exact zeros, such as gradient arrays or covariance matrices.

Recommended Pattern for Gradient Tests

from sklearn.utils._array_api import _atol_for_type

def test_kernel_gradient_array_api(array_namespace, device, dtype_name):
    # ... compute K_grad_np and K_grad_xp_np ...

    rtol = 1e-5 if dtype_name == "float64" else 1e-4
    assert_allclose(K_xp_np, K_np, rtol=rtol)
    # Use atol for gradients since they may contain exact zeros
    assert_allclose(K_grad_xp_np, K_grad_np, rtol=rtol, atol=_atol_for_type(dtype_name))

When to Use Higher Tolerances

Some operations require looser tolerances:

# Complex operations (Bessel functions, optimizers, Cholesky decomposition)
rtol = 1e-4 if dtype_name == "float32" else 1e-10

# Operations with accumulated numerical error
rtol = 1e-3 if dtype_name == "float32" else 1e-6

Testing Unsupported Configurations

Test that appropriate errors are raised:

def test_unsupported_raises(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)
    X = xp.asarray(np.random.randn(100, 10).astype(dtype_name), device=device)

    # GaussianMixture with kmeans init doesn't support array API
    gmm = GaussianMixture(init_params="kmeans")

    with config_context(array_api_dispatch=True):
        with pytest.raises(
            NotImplementedError,
            match="init_params.*not implemented.*array_api_dispatch"
        ):
            gmm.fit(X)

Using array-api-strict for Development

The array-api-strict library is useful for development because:

  1. It's pure Python (no GPU required)
  2. It strictly follows the array API standard
  3. It has simulated devices for testing device handling
pip install array-api-strict
# Run tests with array-api-strict
pytest -k "array_api and array_api_strict" 

Skip Decorators

Skip if array_api_compat not configured

from sklearn.utils._testing import skip_if_array_api_compat_not_configured

@skip_if_array_api_compat_not_configured
def test_array_api_feature():
    ...

Skip specific devices

@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
def test_feature(array_namespace, device, dtype_name):
    if device == "mps" and dtype_name == "float64":
        pytest.skip("MPS doesn't support float64")

    xp = _array_api_for_tests(array_namespace, device)
    # ... test code

CI/CD Considerations

Running Array API Tests Locally

# Install test dependencies
pip install array-api-strict pytest

# Run array API tests
export SCIPY_ARRAY_API=1
pytest -k "array_api" -v

GPU Testing in CI

GPU tests are expensive, so they are run selectively rather than on every change.

Common Test Patterns

Pattern: Test Fitted Attributes

def test_fitted_attrs_namespace(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)
    X = xp.asarray(np.random.randn(100, 10).astype(dtype_name), device=device)

    with config_context(array_api_dispatch=True):
        pca = PCA(n_components=2).fit(X)

        # All array attributes should match input namespace
        for attr in ["components_", "mean_", "singular_values_"]:
            arr = getattr(pca, attr)
            assert get_namespace(arr)[0] == xp
            assert array_device(arr) == array_device(X)

Pattern: Test Transform Output

def test_transform_output(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)
    X = xp.asarray(np.random.randn(100, 10).astype(dtype_name), device=device)

    with config_context(array_api_dispatch=True):
        pca = PCA(n_components=2).fit(X)
        X_transformed = pca.transform(X)

        # Output should be in same namespace/device
        assert get_namespace(X_transformed)[0] == xp
        assert array_device(X_transformed) == array_device(X)
        assert X_transformed.shape == (100, 2)

Pattern: Test Predict Methods

def test_predict_methods(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)
    X = xp.asarray(np.random.randn(100, 10).astype(dtype_name), device=device)
    y = xp.asarray(np.array([0, 1] * 50), device=device)

    with config_context(array_api_dispatch=True):
        clf = LinearDiscriminantAnalysis().fit(X, y)

        # Test all prediction methods
        pred = clf.predict(X)
        assert get_namespace(pred)[0] == xp

        proba = clf.predict_proba(X)
        assert get_namespace(proba)[0] == xp

        decision = clf.decision_function(X)
        assert get_namespace(decision)[0] == xp

Pattern: Test Score Returns Float

def test_score_returns_float(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)
    X = xp.asarray(np.random.randn(100, 10).astype(dtype_name), device=device)
    y = xp.asarray(np.array([0, 1] * 50), device=device)

    with config_context(array_api_dispatch=True):
        clf = LinearDiscriminantAnalysis().fit(X, y)
        score = clf.score(X, y)

        # Score should be Python float, not tensor
        assert isinstance(score, float)

Debugging Array API Issues

Check Namespace

from sklearn.utils._array_api import get_namespace

arr = some_array
xp, is_compliant = get_namespace(arr)
print(f"Namespace: {xp.__name__}, Compliant: {is_compliant}")

Check Device

from sklearn.utils._array_api import device

arr = some_tensor
print(f"Device: {device(arr)}")

Convert to NumPy for Inspection

from sklearn.utils._array_api import _convert_to_numpy, get_namespace

xp, _ = get_namespace(arr)
arr_np = _convert_to_numpy(arr, xp)
print(arr_np)

Array API Utility Functions Reference

Created: 2026-01-07 Last Updated: 2026-01-08

Overview

This document provides a reference for the utility functions in sklearn/utils/_array_api.py. These functions provide the foundation for array API support in scikit-learn. Note that sklearn/utils/_array_api.py may contain functions not shown here; always double-check, as the code is the source of truth.

Purpose of sklearn/utils/_array_api.py

The purpose of this file is to provide a home for helper functions. This includes general utilities like get_namespace, a customized version of array_api_compat's implementation. It also includes functions like nanmin that exist in one or more array libraries, but are not part of the array API standard.

Adding to sklearn/utils/_array_api.py

When you encounter a function that is present in NumPy but not part of the array API standard, it is time to implement that function in sklearn/utils/_array_api.py. Examples are listed in the "Array Operations Not in Standard" section.

To implement a new function, the ideal path is to take the code of that function from NumPy and convert it to use building blocks and idioms that are part of the array API. It is OK to implement only a subset of the functionality the NumPy version provides; be pragmatic. The function signature should mirror NumPy's, though it is fine to omit some arguments or to support only a subset of the possible values for an argument.
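A sketch of the typical shape of such a helper (the function below is illustrative, not an existing utility): delegate to NumPy in the NumPy namespace, otherwise build the result from array API primitives.

import numpy as np

from sklearn.utils._array_api import _is_numpy_namespace, get_namespace

def _ptp(X, axis=None, xp=None):
    """Illustrative helper mirroring np.ptp (peak-to-peak range)."""
    xp, _ = get_namespace(X, xp=xp)
    if _is_numpy_namespace(xp):
        return xp.asarray(np.ptp(X, axis=axis))
    # max and min are part of the array API standard
    return xp.max(X, axis=axis) - xp.min(X, axis=axis)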

Core Functions

get_namespace()

def get_namespace(
    *arrays,
    remove_none=True,
    remove_types=REMOVE_TYPES_DEFAULT,
    xp=None
) -> tuple[ModuleType, bool]

Get the array namespace from input arrays.

Parameters:

  • *arrays: Array objects to inspect
  • remove_none: Whether to ignore None values (default: True)
  • remove_types: Types to ignore, default: (str, list, tuple)
  • xp: Pre-computed namespace to skip inspection

Returns:

  • namespace: The array namespace module (e.g., torch, cupy)
  • is_array_api_compliant: True if dispatch is enabled, False otherwise

Behavior:

  • When array_api_dispatch=False: Always returns (numpy_compat, False)
  • When array_api_dispatch=True: Returns the actual namespace of input arrays
  • Sparse arrays are filtered out
  • Pandas DataFrames/Series are filtered out

Example:

import torch
from sklearn.utils._array_api import get_namespace

X = torch.randn(100, 10)
with config_context(array_api_dispatch=True):
    xp, is_compliant = get_namespace(X)
    # xp is the torch namespace
    # is_compliant is True

get_namespace_and_device()

def get_namespace_and_device(
    *array_list,
    remove_none=True,
    remove_types=REMOVE_TYPES_DEFAULT,
    xp=None
) -> tuple[ModuleType, bool, Device]

Combined function to get namespace and device information.

Parameters: Same as get_namespace()

Returns:

  • namespace: The array namespace module
  • is_array_api_compliant: True if dispatch is enabled
  • device: The device object (e.g., cuda:0, cpu, or None for NumPy)

Example:

X = torch.randn(100, 10, device="cuda")
with config_context(array_api_dispatch=True):
    xp, is_compliant, dev = get_namespace_and_device(X)
    # dev is torch.device('cuda:0')

device()

def device(
    *array_list,
    remove_none=True,
    remove_types=REMOVE_TYPES_DEFAULT
) -> Device | None

Get the device from arrays.

Parameters:

  • *array_list: Arrays to get device from
  • remove_none: Whether to ignore None values
  • remove_types: Types to ignore

Returns:

  • Device object or None if not applicable

Raises:

  • ValueError if arrays are on different devices

Example:

from sklearn.utils._array_api import device

X = torch.randn(10, device="cuda")
dev = device(X)  # Returns cuda:0

move_to()

def move_to(
    *arrays,
    xp,
    device
) -> tuple | Array

Move arrays to a target namespace and device.

Parameters:

  • *arrays: Arrays to move (may include None)
  • xp: Target namespace
  • device: Target device

Returns:

  • Single array if one input, tuple if multiple inputs

Notes:

  • Uses DLPack protocol when available for zero-copy transfer
  • Falls back to NumPy intermediate when DLPack fails
  • Sparse arrays are passed through unchanged (only for NumPy target)
  • None values are passed through unchanged

Example:

from sklearn.utils._array_api import move_to, get_namespace_and_device

X_cuda = torch.randn(100, 10, device="cuda")
y_numpy = np.array([0, 1, 1, 0, ...])

xp, _, dev = get_namespace_and_device(X_cuda)
y_cuda = move_to(y_numpy, xp=xp, device=dev)
# y_cuda is now a torch.Tensor on CUDA

Type and Dtype Functions

supported_float_dtypes()

def supported_float_dtypes(xp, device=None) -> tuple

Get supported floating point dtypes for a namespace/device.

Returns: Tuple of dtypes ordered from highest to lowest precision

Example:

supported_float_dtypes(torch, device="mps")
# Returns (torch.float32,) since MPS doesn't support float64

_max_precision_float_dtype()

def _max_precision_float_dtype(xp, device) -> dtype

Get the highest precision float dtype supported.

Returns: xp.float64 or xp.float32 for devices without float64 support


indexing_dtype()

def indexing_dtype(xp) -> dtype

Get platform-appropriate integer dtype for indexing.

Returns: int32 on 32-bit platforms, int64 on 64-bit
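Example (a sketch; xp, n_samples, and device_ are assumed to be in scope):

idx_dtype = indexing_dtype(xp)
indices = xp.arange(n_samples, dtype=idx_dtype, device=device_)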


Namespace Checking Functions

_is_numpy_namespace()

def _is_numpy_namespace(xp) -> bool

Check if namespace is NumPy or NumPy-compat.

Example:

xp, _ = get_namespace(np.array([1, 2, 3]))
_is_numpy_namespace(xp)  # True

_is_xp_namespace()

def _is_xp_namespace(xp, name) -> bool

Check if namespace matches a specific library name.

Example:

_is_xp_namespace(xp, "torch")  # True for torch namespace
_is_xp_namespace(xp, "cupy")   # True for cupy namespace

Array Operations Not in Standard

These functions provide implementations of operations not in the array API standard.

_nanmin(), _nanmax(), _nanmean()

def _nanmin(X, axis=None, xp=None) -> Array
def _nanmax(X, axis=None, xp=None) -> Array
def _nanmean(X, axis=None, xp=None) -> Array

NaN-aware reductions. For NumPy, delegates to np.nanmin etc. For other namespaces, implements with masking.


_nansum()

def _nansum(X, axis=None, xp=None, keepdims=False, dtype=None) -> Array

NaN-aware sum.


_average()

def _average(
    a,
    axis=None,
    weights=None,
    normalize=True,
    xp=None
) -> Array

Weighted average, array API compatible version of np.average.

Parameters:

  • a: Input array
  • axis: Axis to average over
  • weights: Weight array (optional)
  • normalize: If True, divide by sum of weights; if False, just sum weighted values

Example:

score = _average(correct, weights=sample_weight, normalize=True, xp=xp)

_median()

def _median(x, axis=None, keepdims=False, xp=None) -> Array

Median computation. Uses torch.quantile for PyTorch, xp.median if available, falls back to NumPy.


_bincount()

def _bincount(array, weights=None, minlength=None, xp=None) -> Array

Array API version of np.bincount. Uses namespace's bincount if available, otherwise converts through NumPy.


_isin()

def _isin(element, test_elements, xp, assume_unique=False, invert=False) -> Array

Check element membership. For NumPy, uses np.isin. For other namespaces, implements with sorting-based algorithm.


_logsumexp()

def _logsumexp(array, axis=None, xp=None) -> Array

Compute log-sum-exp in a numerically stable way. Replacement for scipy.special.logsumexp.


_ravel()

def _ravel(array, xp=None) -> Array

Flatten array. For NumPy, uses np.ravel. For others, uses xp.reshape(array, (-1,)).


Diagonal Operations

_fill_diagonal()

def _fill_diagonal(array, value, xp) -> None

Fill the diagonal of a 2D array in-place.

Note: Modifies array in-place.
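Example (a sketch; K is a square 2D array):

xp, _ = get_namespace(K)
_fill_diagonal(K, 1.0, xp=xp)  # in-place: every diagonal entry becomes 1.0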


_add_to_diagonal()

def _add_to_diagonal(array, value, xp) -> None

Add values to the diagonal of a 2D array in-place.


Conversion Functions

_convert_to_numpy()

def _convert_to_numpy(array, xp) -> np.ndarray

Convert any array to NumPy ndarray.

Handles:

  • PyTorch: Calls .cpu().numpy()
  • CuPy: Calls .get()
  • array-api-strict: Moves to CPU device first
  • Others: Uses np.asarray()

_estimator_with_converted_arrays()

def _estimator_with_converted_arrays(estimator, converter) -> Estimator

Create a new estimator with all array attributes converted.

Parameters:

  • estimator: Fitted estimator
  • converter: Function to convert arrays (e.g., lambda x: x.get() for CuPy)

Example:

from sklearn.utils._array_api import _estimator_with_converted_arrays

# Convert GPU estimator to CPU
lda_cpu = _estimator_with_converted_arrays(
    lda_gpu,
    converter=lambda x: x.cpu().numpy()
)

Array Creation Helpers

_asarray_with_order()

def _asarray_with_order(
    array,
    dtype=None,
    order=None,
    copy=None,
    *,
    xp=None,
    device=None
) -> Array

Create array with memory order support for NumPy, ignored for other namespaces.

Note: The order parameter only affects NumPy arrays. For array API compliance, memory layout is not guaranteed.


Math Functions

_expit()

def _expit(X, xp=None) -> Array

Logistic sigmoid function. For NumPy, uses scipy.special.expit. For others, computes 1 / (1 + exp(-X)).


_xlogy()

def _xlogy(x, y, xp=None) -> Array

Compute x * log(y) with proper handling of x == 0.


Linear Algebra Helpers

_cholesky()

def _cholesky(covariance, xp) -> Array

Cholesky decomposition. For NumPy, uses scipy.linalg.cholesky(lower=True). For others, uses xp.linalg.cholesky.


_linalg_solve()

def _linalg_solve(cov_chol, eye_matrix, xp) -> Array

Solve triangular system. For NumPy, uses scipy.linalg.solve_triangular. For others, uses xp.linalg.solve.


Testing Utilities

yield_namespaces()

def yield_namespaces(include_numpy_namespaces=True) -> Iterator[str]

Yield namespace names for testing: "numpy", "array_api_strict", "cupy", "torch".


yield_namespace_device_dtype_combinations()

def yield_namespace_device_dtype_combinations(
    include_numpy_namespaces=True
) -> Iterator[tuple[str, Device, str]]

Yield (namespace, device, dtype) combinations for comprehensive testing.

Yields:

  • ("numpy", None, None)
  • ("array_api_strict", CPU_DEVICE, "float64")
  • ("array_api_strict", device1, "float32")
  • ("torch", "cpu", "float64")
  • ("torch", "cpu", "float32")
  • ("torch", "cuda", "float64")
  • ("torch", "cuda", "float32")
  • ("torch", "mps", "float32")
  • ("cupy", None, None)

_atol_for_type()

def _atol_for_type(dtype_or_dtype_name) -> float

Get appropriate absolute tolerance for floating point comparisons.

Returns: eps * 1000 for the given dtype


Count Functions

_count_nonzero()

def _count_nonzero(X, axis=None, sample_weight=None, xp=None, device=None) -> Array

Count non-zero elements, optionally weighted. For sparse NumPy arrays, uses efficient sparse implementation.


Set Operations

_union1d()

def _union1d(a, b, xp) -> Array

Union of two 1D arrays.


Miscellaneous

size()

def size(x) -> int

Total number of elements in array: math.prod(x.shape).


_check_array_api_dispatch()

def _check_array_api_dispatch(array_api_dispatch) -> None

Verify array API requirements are met. Checks:

  • SciPy version >= 1.14.0
  • SCIPY_ARRAY_API=1 environment variable is set

Raises: ImportError or RuntimeError if requirements not met.


Best Practices

Use Existing Utility Functions

Always use the utility functions from sklearn/utils/_array_api.py instead of ad-hoc patterns.

Bad: Ad-hoc device extraction

device = getattr(X, "device", None)
result = xp.asarray(value, device=device)

Good: Use the device() utility function

from sklearn.utils._array_api import device as array_device

result = xp.asarray(value, device=array_device(X))

Rationale: The device() function in _array_api.py handles edge cases consistently across all array backends, including NumPy arrays which don't have a .device attribute. Using ad-hoc patterns like getattr(X, "device", None) can lead to inconsistent behavior and misses the benefit of centralized maintenance.

Avoid NumPy Scalar Types in Array Arithmetic

When multiplying arrays by scalars that come from NumPy or SciPy functions (e.g., scipy.special.gamma), convert them to Python float first.

Bad: Using numpy scalar directly

from scipy.special import gamma

coef = (2 ** (1.0 - nu)) / gamma(nu)  # numpy.float64
K = coef * array  # May corrupt device info for array_api_strict

Good: Convert to Python float first

from scipy.special import gamma

coef = float((2 ** (1.0 - nu)) / gamma(nu))  # Python float
K = coef * array  # Preserves device info

Rationale: When a numpy.float64 is multiplied with an array_api_strict array, the result may have a corrupted device attribute (e.g., 'cpu' as a string instead of a proper Device object). Python float scalars broadcast correctly and preserve device information.

Import Placement

Place all imports at the top of the file, not inline within functions.

Bad: Inline imports of standard dependencies

def some_function(X, xp):
    from scipy.spatial.distance import cdist  # Should be at top of file
    return cdist(X, X)

Good: Standard dependencies at top of file

from scipy.spatial.distance import cdist

def some_function(X, xp):
    return cdist(X, X)

Exception: Optional array library dependencies (torch, cupy, array_api_strict) should be imported inline, only when we know the user has them installed. This is typically determined by checking if the input array is from that namespace.

Good: Conditional inline import of optional dependencies

from scipy.spatial.distance import cdist

def _cdist(X, Y, metric="euclidean", xp=None):
    xp, _, device_ = get_namespace_and_device(X, Y, xp=xp)
    
    if _is_numpy_namespace(xp):
        return xp.asarray(cdist(X, Y, metric=metric))
    
    if _is_xp_namespace(xp, "torch"):
        # No need to import torch as it is available as `xp` already
        return xp.cdist(X, Y, p=2)
    
    if _is_xp_namespace(xp, "cupy"):
        from cupyx.scipy.spatial.distance import cdist  # Only when input is cupy
        return cdist(X, Y, metric=metric)

Rationale: Optional dependencies like torch and cupy are not required for scikit-learn to function. Importing them at module load time would cause ImportError for users who don't have them installed. By importing conditionally inside functions, we only attempt the import when we know the dependency is available (because the user passed an array from that library).

Use Functional Form for Matrix Transpose

Use xp.matrix_transpose(L) instead of the .mT attribute for matrix transposition.

Bad: Using .mT attribute

# CuPy arrays don't have .mT attribute
L_transposed = L.mT

Good: Use functional form

L_transposed = xp.matrix_transpose(L)

Rationale: The .mT attribute (matrix transpose) was added in the Array API 2022.12 standard, but CuPy hasn't fully adopted it. CuPy arrays raise AttributeError: 'ndarray' object has no attribute 'mT'. The functional form xp.matrix_transpose() is provided by array_api_compat and works consistently across all backends (NumPy, PyTorch, CuPy, etc.).

Note: For 2D arrays, .mT and .T are equivalent, but .mT specifically transposes only the last two dimensions (important for batched operations). The functional xp.matrix_transpose() has the same semantics as .mT.


Preserve Input Dtype on Limited Devices

Some devices don't support float64 (double precision). Don't hardcode float64 conversions.

Bad: Hardcoded float64 conversion

def process_array(X, xp):
    # This will fail on MPS devices
    X = xp.astype(X, xp.float64)
    return X

Good: Preserve input dtype

def process_array(X, target_array, xp):
    # Match a reference array's dtype instead of hardcoding float64,
    # staying compatible with devices that lack double precision
    if X.dtype != target_array.dtype:
        X = xp.astype(X, target_array.dtype)
    return X

Rationale: Apple's MPS (Metal Performance Shaders) device does not support float64. Attempting to convert to float64 on MPS raises TypeError: Cannot convert a MPS Tensor to float64 dtype. By preserving the input array's dtype, code remains compatible with all devices.

Testing

Running Tests with MPS Device

On Apple Silicon Macs, you can run tests using the MPS (Metal Performance Shaders) device for GPU acceleration:

PYTORCH_ENABLE_MPS_FALLBACK=1 SCIPY_ARRAY_API=1 pytest -v -k "array_api and torch"

Environment variables:

  • PYTORCH_ENABLE_MPS_FALLBACK=1: Allows PyTorch to fall back to CPU for operations not supported on MPS (prevents hard errors)
  • SCIPY_ARRAY_API=1: Enables Array API support in scipy

Note: Without PYTORCH_ENABLE_MPS_FALLBACK=1, tests may fail with errors about unsupported operations. The fallback allows tests to complete while still exercising MPS for supported operations.

Devices Without float64 Support

Some GPU devices do not support double precision (float64) floating point numbers:

  • CPU: float64 ✅ supported (full support)
  • CUDA (NVIDIA): float64 ✅ supported (full support)
  • MPS (Apple Silicon): float64 ❌ not supported (only float32)
  • XPU (Intel): float64 ✅ supported (full support)

When writing Array API compatible code:

  1. Never hardcode float64: Use X.dtype to preserve the input array's dtype
  2. Use input dtype for new arrays: When creating new arrays, use dtype=X.dtype and device=array_device(X)
  3. Test with float32 configurations: The test suite includes float32 parameter combinations to catch float64 assumptions

Example of creating compatible arrays:

xp, _, device_ = get_namespace_and_device(X)

# Create array matching input dtype and device
ones = xp.ones(shape, dtype=X.dtype, device=device_)

# Convert scalar to array on correct device with correct dtype  
value = xp.asarray(1.0, dtype=X.dtype, device=device_)

Scipy Dtype Preservation

Several scipy functions (scipy.linalg.cholesky, scipy.linalg.cho_solve, scipy.spatial.distance.pdist, etc.) upcast float32 inputs to float64 for their computations. This causes dtype mismatches when comparing numpy results with array API results (which preserve the input dtype).

The utility functions in _array_api.py (_cholesky, _cho_solve, _cdist, _pdist, _squareform) handle this by casting the result back to the input dtype when the input was floating point:

def _cholesky(covariance, xp):
    if _is_numpy_namespace(xp):
        # scipy may upcast float32 to float64; cast back to preserve input dtype
        result = scipy.linalg.cholesky(covariance, lower=True)
        return result.astype(covariance.dtype, copy=False)
    else:
        return xp.linalg.cholesky(covariance)

Note: The dtype cast-back only happens for floating point inputs. Integer inputs (e.g., from structured data with string features) are not cast back to avoid precision loss.

Array API Work Template

Created: 2026-01-09 Last Updated: 2026-01-09

Overview

This document provides a structured workflow for converting scikit-learn components to support the Array API. It serves as a checklist to ensure important steps aren't missed and work proceeds in the right order.

The workflow is based on lessons learned from the Gaussian Process Regressor conversion (see 2026-01-08-gpr-array-api-conversion.md).

Work Decomposition

Before starting implementation, decompose the work into smaller, incremental pieces:

  1. Identify dependencies: Map out which components depend on others
  2. Start with fundamentals: Convert lower-level components before higher-level ones
  3. Test incrementally: Each component should be fully tested before moving to the next

Example: GPR Conversion

The GPR work followed this pattern:

  1. Utility functions first: Added _cdist, _pdist, _squareform, _cho_solve, _kv to _array_api.py
  2. Kernels second: Converted all kernel classes (RBF, Matern, etc.)
  3. Estimator last: Converted GaussianProcessRegressor once kernels were working

This incremental approach catches issues early and makes debugging easier.


Workflow Phases

Phase 1: Pre-work Validation

Before starting any conversion work:

  • Run the existing test suite to confirm the repository is in a good state
  • Identify the scope of work and create a plan document in agents/plans/
  • Map dependencies between components to determine conversion order
# Run tests for the module you'll be modifying
pytest sklearn/gaussian_process/tests/ -v

Phase 2: Testing Setup

Set up tests that compare Array API results against NumPy:

  • Create test file or add test functions following patterns in array-api-testing.md
  • Use yield_namespace_device_dtype_combinations() for parametrization
  • NumPy results are the source of truth - compare other backends against NumPy
  • Include appropriate tolerances for float32 vs float64
# Import locations may vary slightly between scikit-learn versions
import numpy as np
import pytest
from numpy.testing import assert_allclose

from sklearn import config_context
from sklearn.utils._array_api import (
    _convert_to_numpy,
    _get_namespace_device_dtype_ids,
    yield_namespace_device_dtype_combinations,
)
from sklearn.utils._testing import (
    _array_api_for_tests,
    skip_if_array_api_compat_not_configured,
)


@skip_if_array_api_compat_not_configured
@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(),
    ids=_get_namespace_device_dtype_ids,
)
def test_component_array_api(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)

    # Create test data
    X_np = np.array(..., dtype=dtype_name)
    X_xp = xp.asarray(X_np, device=device)

    # Run with NumPy (baseline)
    result_np = component(X_np)

    # Run with Array API
    with config_context(array_api_dispatch=True):
        result_xp = component(X_xp)

    # Compare
    result_xp_np = _convert_to_numpy(result_xp, xp=xp)
    assert_allclose(result_np, result_xp_np, rtol=1e-5)

Phase 3: Implementation

Convert the code using Array API utilities:

  • Use utilities from array-api-utilities.md
  • Add new utility functions to sklearn/utils/_array_api.py as needed
  • Run tests frequently during development
  • Handle device placement correctly (use device() utility, not getattr)

Key implementation patterns:

from sklearn.utils._array_api import (
    get_namespace_and_device,
    move_to,
    _convert_to_numpy,
)

def fit(self, X, y):
    xp, is_array_api, device_ = get_namespace_and_device(X)
    
    # Move y to match X's namespace/device
    y = move_to(y, xp=xp, device=device_)
    
    # Use xp instead of np throughout
    result = xp.sum(X, axis=0)
    
    # Store fitted attributes in the same namespace
    self.mean_ = xp.mean(X, axis=0)
    return self

Common pitfalls to avoid:

  • Don't use getattr(X, "device", None) - use the device() utility
  • Don't multiply arrays by NumPy scalars - convert to Python float first
  • Don't hardcode float64 - preserve input dtype for MPS compatibility
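A short sketch tying these together (the function name and its body are illustrative; device and get_namespace are the utilities from sklearn.utils._array_api):

import numpy as np
from sklearn.utils._array_api import device as array_device, get_namespace

def scale_columns(X, weights):
    xp, _ = get_namespace(X)
    device_ = array_device(X)  # not getattr(X, "device", None)

    # np.mean returns a NumPy scalar; convert to a Python float before
    # multiplying a non-NumPy array with it
    scale = float(np.mean(weights))

    # new arrays are created with X's dtype and device (no hardcoded float64)
    ones = xp.ones((X.shape[1],), dtype=X.dtype, device=device_)
    return X * ones * scale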

Phase 4: Cleanup Pass

After implementation is working:

  • Review for code quality and DRY principles
  • Ensure all imports are at the top of the file (except optional array library imports)
  • Remove any debug code or print statements
  • Check that code follows existing conventions in the codebase
  • Run linter and fix any issues

Phase 5: Performance Testing

Benchmark both execution time AND memory usage:

  • Create or adapt a benchmark script (see benchmark_gpr_performance.py)
  • Use tracemalloc to measure peak memory usage
  • Run Phase 1: Compare current branch vs main with NumPy (dispatch off)
  • Run Phase 2: Compare backends (numpy, torch-cpu, torch-mps)

Performance Requirements

| Check            | Requirement                              |
|------------------|------------------------------------------|
| NumPy regression | No significant regression vs main branch |
| torch-cpu        | Should be similar to NumPy performance   |
| torch-mps        | Should be similar to NumPy performance   |
| Memory           | No significant increase vs main branch   |

Benchmark Script Pattern

import time
import tracemalloc

import numpy as np

def benchmark_operation(func, n_reps=5):
    times = []
    peak_memory = 0

    for _ in range(n_reps):
        tracemalloc.start()
        start = time.perf_counter()
        try:
            func()
        finally:
            elapsed = time.perf_counter() - start
            _, peak = tracemalloc.get_traced_memory()
            tracemalloc.stop()

        times.append(elapsed)
        peak_memory = max(peak_memory, peak)

    return np.median(times), peak_memory

Running Benchmarks

# Phase 1: Regression check
python agents/scripts/benchmark_script.py --phase 1 --output branch.json
git stash && git checkout main
python agents/scripts/benchmark_script.py --phase 1 --output main.json
git checkout - && git stash pop
python agents/scripts/benchmark_script.py --compare main.json branch.json

# Phase 2: Backend comparison
SCIPY_ARRAY_API=1 PYTORCH_ENABLE_MPS_FALLBACK=1 \
    python agents/scripts/benchmark_script.py --phase 2 --backend all

Phase 6: Final Cleanup

After performance testing:

  • Address any performance issues discovered during benchmarking
  • Final review pass for code quality
  • Ensure all tests pass across all backends
  • Update plan document with completion status

Documentation Requirements

Throughout the work:

  • Update agents/designs/ with new learnings and best practices
  • Document any new utility functions added to _array_api.py in array-api-utilities.md
  • Note any backend-specific quirks or limitations discovered
  • Record benchmark results in agents/plans/

Key Files Reference

| File                                  | Purpose                                           |
|---------------------------------------|---------------------------------------------------|
| array-api-architecture.md             | Overview of Array API support architecture        |
| array-api-utilities.md                | Reference for utility functions in _array_api.py  |
| array-api-testing.md                  | Testing patterns and examples                     |
| sklearn/utils/_array_api.py           | Core utility functions                            |
| sklearn/utils/tests/test_array_api.py | Tests for utility functions                       |

Checklist Summary

[ ] Phase 1: Pre-work Validation
    [ ] Tests pass
    [ ] Plan created
    [ ] Dependencies mapped

[ ] Phase 2: Testing Setup
    [ ] Test file created
    [ ] Parametrized for all backends
    [ ] NumPy baseline comparison

[ ] Phase 3: Implementation
    [ ] Code converted
    [ ] Tests passing
    [ ] Device handling correct

[ ] Phase 4: Cleanup Pass
    [ ] Code quality review
    [ ] Imports organized
    [ ] Debug code removed

[ ] Phase 5: Performance Testing
    [ ] Benchmark script ready
    [ ] No NumPy regression
    [ ] Backend comparison done
    [ ] Memory profiled

[ ] Phase 6: Final Cleanup
    [ ] Performance issues addressed
    [ ] Final code quality review
    [ ] All tests pass across backends
    [ ] Plan document updated

[ ] Documentation
    [ ] agents/designs/ updated
    [ ] New utilities documented
    [ ] Quirks noted

Memory Benchmarking Guidelines for Array API Backends

Overview

When comparing memory usage across different array backends (NumPy, PyTorch CPU, PyTorch CUDA, CuPy, etc.), it's critical to use consistent measurement methodologies. This document describes common pitfalls and recommended approaches.

The Problem: Inconsistent Memory Measurement

During GPR performance benchmarking (January 2026), we observed what appeared to be a 26-39x memory difference between NumPy and PyTorch CUDA for predict() operations. This turned out to be a measurement artifact, not a real difference.

What Happened

The benchmark used:

  • NumPy: tracemalloc to measure memory
  • PyTorch CUDA: torch.cuda.max_memory_allocated()

These measure fundamentally different things:

| Tool                              | What It Measures                                           |
|-----------------------------------|------------------------------------------------------------|
| tracemalloc                       | NEW Python object allocations during the measured block    |
| torch.cuda.max_memory_allocated() | TOTAL GPU memory allocated, including pre-existing tensors |

Example: GPR predict() at n=10,000

NumPy reported:    30 MB  (only new allocations during predict)
PyTorch reported: 803 MB  (total GPU memory including stored L_ matrix)

Actual comparison:
- NumPy stored model:     801 MB (L_, X_train_, alpha_ in RAM - not counted!)
- PyTorch stored model:   801 MB (same tensors on GPU - counted!)
- Additional for predict: ~32 MB (both backends, nearly identical)

The L_ matrix (Cholesky factor, 800 MB) was already allocated before predict() started. NumPy's tracemalloc didn't count it because it wasn't a new allocation. PyTorch's memory API counted it because it's part of total GPU memory.

Recommended Approaches

Option 1: Measure Delta Memory (Recommended for Operation Comparison)

Measure memory before and after an operation to get the delta:

# NumPy
import tracemalloc
tracemalloc.start()
# ... operation ...
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
delta = peak  # This is new allocations only

# PyTorch CUDA
import torch
torch.cuda.synchronize()
before = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()
# ... operation ...
torch.cuda.synchronize()
peak = torch.cuda.max_memory_allocated()
delta = peak - before  # Delta from baseline

Option 2: Measure Total Stored State (Recommended for Model Size)

To compare model storage requirements, measure the actual tensors/arrays stored:

# NumPy
total_bytes = sum(arr.nbytes for arr in [model.L_, model.alpha_, model.X_train_])

# PyTorch
total_bytes = sum(t.numel() * t.element_size() for t in [model.L_, model.alpha_, model.X_train_])

Note: the tensors/arrays stored depend on the model. Don't just copy the names from the example.
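A backend-agnostic sketch combining the two (assumes arrays expose either .nbytes or .numel()/.element_size()):

def stored_bytes(arrays):
    # NumPy and CuPy arrays expose .nbytes; PyTorch tensors expose
    # .numel() and .element_size()
    total = 0
    for a in arrays:
        if hasattr(a, "nbytes"):
            total += int(a.nbytes)
        else:
            total += a.numel() * a.element_size()
    return total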

Option 3: Fresh Process Measurement (Most Accurate but Slow)

For the most accurate comparison, measure in a fresh process for each backend:

def measure_in_subprocess(backend, operation):
    """Run measurement in isolated subprocess."""
    # Each measurement starts with clean memory state
    ...
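A minimal sketch of that approach; the script name and JSON protocol here are assumptions, not an existing tool:

import json
import subprocess
import sys

def measure_in_subprocess(backend):
    """Run a measurement script in a fresh interpreter so each backend
    starts from a clean memory state."""
    # hypothetical script: prints a JSON dict of memory stats to stdout
    proc = subprocess.run(
        [sys.executable, "measure_one_backend.py", "--backend", backend],
        capture_output=True, text=True, check=True,
    )
    return json.loads(proc.stdout)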

PyTorch-Specific Considerations

CUDA Memory Caching

PyTorch's CUDA allocator caches memory blocks for reuse. This means:

  • torch.cuda.memory_allocated() shows actually used memory
  • torch.cuda.memory_reserved() shows cached memory (may be higher)
  • Use torch.cuda.empty_cache() before measurements to release cached blocks
import torch
import gc

gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Now measure...

inference_mode() Does NOT Reduce Memory

Testing showed that torch.inference_mode() has no effect on memory usage for scikit-learn operations. This is because scikit-learn doesn't use PyTorch's autograd - tensors are created without requires_grad=True.

# These use identical memory:
result = model.predict(X)

with torch.inference_mode():
    result = model.predict(X)

Memory Breakdown for GPR (Reference)

For a fitted GaussianProcessRegressor with n training samples and d features (sizes assume float64, i.e. 8 bytes per element):

| Component     | Size            | Notes                        |
|---------------|-----------------|------------------------------|
| L_ (Cholesky) | 8 * n^2 bytes   | Dominates memory for large n |
| X_train_      | 8 * n * d bytes | Training features            |
| alpha_        | 8 * n bytes     | Dual coefficients            |
| y_train_      | 8 * n bytes     | Training targets             |

For n=10,000, d=10:

  • L_: 800 MB
  • X_train_: 0.8 MB
  • alpha_ + y_train_: 0.16 MB
  • Total: ~801 MB (same for NumPy and PyTorch)
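The totals are easy to verify (float64 = 8 bytes per element):

n, d = 10_000, 10

L_bytes = 8 * n * n      # 800 MB
X_bytes = 8 * n * d      # 0.8 MB
vec_bytes = 2 * 8 * n    # alpha_ + y_train_, 0.16 MB

total_mb = (L_bytes + X_bytes + vec_bytes) / 1e6  # ~801 MB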

Template Benchmark Function

import gc

import numpy as np


def benchmark_memory(func, backend, warmup=1, n_reps=3):
    """
    Measure memory usage consistently across backends.

    Returns dict with:
    - 'peak_delta': Peak additional memory during operation
    - 'baseline': Memory before operation
    - 'peak_total': Total peak memory (baseline + delta)
    """
    results = {'peak_delta': [], 'baseline': [], 'peak_total': []}

    # Warmup
    for _ in range(warmup):
        func()

    for _ in range(n_reps):
        gc.collect()

        if backend == "numpy":
            import tracemalloc
            tracemalloc.start()
            func()
            _, peak = tracemalloc.get_traced_memory()
            tracemalloc.stop()
            results['peak_delta'].append(peak)
            results['baseline'].append(0)  # tracemalloc doesn't give a baseline
            results['peak_total'].append(peak)

        elif backend.startswith("torch"):
            import torch
            if "cuda" in backend:
                torch.cuda.empty_cache()
                torch.cuda.synchronize()
                baseline = torch.cuda.memory_allocated()
                torch.cuda.reset_peak_memory_stats()

                func()

                torch.cuda.synchronize()
                peak_total = torch.cuda.max_memory_allocated()

                results['baseline'].append(baseline)
                results['peak_total'].append(peak_total)
                results['peak_delta'].append(peak_total - baseline)
            else:
                # torch-cpu: no easy peak-memory tracking
                func()
                results['peak_delta'].append(None)
                results['baseline'].append(None)
                results['peak_total'].append(None)

    def _median(values):
        # Guard against all-None results (e.g. torch-cpu)
        values = [v for v in values if v is not None]
        return float(np.median(values)) if values else None

    # Return medians
    return {key: _median(vals) for key, vals in results.items()}
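Example usage (model and X stand in for whatever is being benchmarked):

stats = benchmark_memory(lambda: model.predict(X), backend="torch-cuda")
print(f"peak delta: {stats['peak_delta'] / 1e6:.1f} MB")
print(f"peak total: {stats['peak_total'] / 1e6:.1f} MB")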

Summary

  1. Always measure delta memory (peak - baseline) when comparing operations
  2. Document what you're measuring - total vs delta, stored model vs operation
  3. Be aware that tracemalloc and torch.cuda APIs measure different things
  4. Clean up before measuring: gc.collect(), torch.cuda.empty_cache()
  5. The stored model (L_ matrix) dominates memory for GPR - this is algorithmic, not backend-specific