Scikit-learn's Array API support enables estimators and functions to work with arrays from different libraries (NumPy, CuPy, PyTorch) without modification. This allows computations to run on GPUs when using GPU-backed array libraries.
The implementation follows the Array API Standard, a specification that defines a common API for array manipulation libraries.
Architecture Diagram
flowchart TB
subgraph user [User Code]
UserInput["X, y (torch/cupy/numpy)"]
end
subgraph config [Configuration]
ConfigDispatch["array_api_dispatch=True"]
SciPyEnv["SCIPY_ARRAY_API=1"]
end
subgraph sklearn [scikit-learn]
subgraph utils [sklearn.utils._array_api]
GetNS["get_namespace()"]
GetNSD["get_namespace_and_device()"]
MoveTo["move_to()"]
Helpers["Helper functions"]
end
subgraph estimator [Estimator/Function]
Fit["fit() / transform() / predict()"]
Tags["__sklearn_tags__()"]
end
subgraph externals [sklearn.externals]
AAC["array_api_compat"]
AAX["array_api_extra"]
end
end
subgraph backends [Array Backends]
NumPy["numpy"]
CuPy["cupy"]
Torch["torch"]
Strict["array-api-strict"]
end
UserInput --> Fit
ConfigDispatch --> GetNS
SciPyEnv --> GetNS
Fit --> GetNS
GetNS --> AAC
AAC --> backends
Helpers --> AAX
MoveTo --> backends
Configuration Requirements
1. SciPy Array API Environment Variable
Before importing scipy or scikit-learn:
export SCIPY_ARRAY_API=1
This is required because scikit-learn uses scipy for some operations, and scipy needs this flag to enable its own array API support.
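If you prefer to set this from Python, a minimal sketch (assuming neither scipy nor scikit-learn has been imported yet in the process) is:

import os

# Must run before the first import of scipy or scikit-learn in this process.
os.environ["SCIPY_ARRAY_API"] = "1"

import scipy
import sklearn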
2. Enable Array API Dispatch
import sklearn

sklearn.set_config(array_api_dispatch=True)

# Or use a context manager:
from sklearn import config_context

with config_context(array_api_dispatch=True):
    model.fit(X_gpu, y_gpu)
Supported Libraries
| Library | Device Support | Notes |
|---|---|---|
| NumPy | CPU | Default, always works |
| CuPy | CUDA GPU | Requires CUDA-capable GPU |
| PyTorch | CPU, CUDA, MPS, XPU | Most flexible device support |
| array-api-strict | CPU (simulated devices) | For testing/development only |
Device Considerations
PyTorch MPS (Apple Silicon GPU): Limited to float32, some operations fall back to CPU
float64 support: Not all devices support float64 (e.g., MPS). Scikit-learn falls back to float32 when needed.
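As a rough sketch of how code can respect that fallback (the helper name and the device keyword of supported_float_dtypes are assumptions here):

from sklearn.utils._array_api import get_namespace_and_device, supported_float_dtypes

def widest_float_dtype(X):
    # Hypothetical helper: the first entry is the widest float dtype available
    # on this namespace/device; on MPS that is float32 because float64 is missing.
    xp, _, device_ = get_namespace_and_device(X)
    return supported_float_dtypes(xp, device=device_)[0]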
Key utilities in sklearn.utils._array_api:
get_namespace_and_device() - Also returns device info
move_to() - Transfer arrays between namespaces/devices
Helper functions for operations not in the array API spec
Namespace Flow
sequenceDiagram
participant User
participant Estimator
participant get_namespace
participant array_api_compat
participant Backend
User->>Estimator: fit(X_gpu, y)
Estimator->>get_namespace: get_namespace(X_gpu)
get_namespace->>get_namespace: Check array_api_dispatch config
get_namespace->>array_api_compat: get_namespace(X_gpu)
array_api_compat->>array_api_compat: Introspect array type
array_api_compat-->>get_namespace: xp (torch namespace)
get_namespace-->>Estimator: (xp, is_array_api=True)
Estimator->>Backend: xp.mean(X_gpu), xp.linalg.svd(X_gpu)
Backend-->>Estimator: Results on GPU
Estimator-->>User: Fitted model with GPU attributes
Input Conversion Rules
Estimators: "Everything Follows X"
When fitting an estimator with array API inputs:
The namespace and device are determined from X
All other inputs (y, sample_weight, constructor arrays) are converted to match X
Fitted attributes are stored in the same namespace/device as X
predict()/transform() expect inputs from the same namespace/device
# Example: X on GPU, y on CPU -> y converted to GPU
with config_context(array_api_dispatch=True):
    lda.fit(X_cuda, y_numpy)  # y is moved to CUDA
    # lda.coef_ is now a CUDA tensor
Metrics: "Everything Follows y_pred"
When calling scoring functions:
The namespace and device are determined from y_pred
All other inputs (y_true, sample_weight) are converted to match y_pred
Scalar outputs return Python floats
Multi-value outputs return arrays in the same namespace
# Example: y_pred on GPU, y_true on CPU -> y_true converted to GPU
with config_context(array_api_dispatch=True):
    score = accuracy_score(y_true_numpy, y_pred_cuda)
    # Returns a Python float, not a tensor
When Array API Dispatch is Disabled
When array_api_dispatch=False (the default):
get_namespace() always returns the numpy-compat wrapper
is_array_api_compliant is always False
Scikit-learn expects NumPy arrays and calls np.asarray() on inputs
Some array types (like GPU tensors) may fail or trigger implicit copies
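A minimal sketch of the default behavior (NumPy input, dispatch left disabled):

import numpy as np
from sklearn.utils._array_api import get_namespace

X = np.asarray([[1.0, 2.0], [3.0, 4.0]])
# With array_api_dispatch at its default (False), the NumPy-compat wrapper is
# returned and the inputs are reported as not array API compliant.
xp, is_array_api_compliant = get_namespace(X)
assert is_array_api_compliant is False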
Experimental Status
Array API support is currently experimental: the set of covered estimators, the supported operations, and the behavior may change between releases without following the usual deprecation cycle.
This document describes the patterns and principles for adding Array API support to scikit-learn estimators. Follow these guidelines when converting an estimator to support GPU arrays.
Prerequisites for an Estimator
An estimator can be converted to support array API if:
It primarily uses NumPy operations (not Cython)
It doesn't rely on scipy functions that lack array API support
The algorithm can be expressed using array API standard operations
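Before converting, it can help to check whether an estimator (with its current parameters) already advertises support via its tags; a sketch, assuming a scikit-learn version where the tag is exposed as shown:

from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

est = LinearDiscriminantAnalysis()
# The array_api_support tag reflects whether this estimator/parameter
# combination is expected to work with array API inputs.
print(est.__sklearn_tags__().array_api_support)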
Step-by-Step Implementation Pattern
Step 1: Add Required Imports
from sklearn.utils._array_api import (
    get_namespace,
    get_namespace_and_device,
    device,
    supported_float_dtypes,
    # Add any helper functions you need:
    _average,
    _nanmean,
    _convert_to_numpy,
)
Step 5: Replace NumPy Operations with Namespace Operations
| NumPy | Array API Equivalent |
|---|---|
| np.mean(X) | xp.mean(X) |
| np.std(X) | xp.std(X) |
| np.zeros((n, m)) | xp.zeros((n, m), device=device(X)) |
| np.ones_like(X) | xp.ones_like(X) |
| np.concatenate([a, b]) | xp.concat([a, b]) |
| np.dot(a, b) | a @ b or xp.matmul(a, b) |
| np.sum(X, axis=0) | xp.sum(X, axis=0) |
| np.unique(y) | xp.unique_values(y) |
| np.argsort(X) | xp.argsort(X, stable=True) |
| scipy.linalg.svd(X) | xp.linalg.svd(X) |
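For example, a (hypothetical) NumPy-only helper and its namespace-aware rewrite using the substitutions above:

import numpy as np
from sklearn.utils._array_api import get_namespace

def center_numpy(X):
    # Original NumPy-only version
    return X - np.mean(X, axis=0)

def center_xp(X):
    # Same logic, routed through the namespace of X
    xp, _ = get_namespace(X)
    return X - xp.mean(X, axis=0)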
Step 6: Handle Array Creation with Device
Always specify device when creating new arrays:
# Wrong - may create on the wrong device
zeros = xp.zeros((n, m))

# Correct - explicitly set device
zeros = xp.zeros((n, m), device=device(X), dtype=X.dtype)

# Also correct - infer from an existing array
zeros = xp.zeros_like(X)
Step 7: Handle Conditional NumPy/SciPy Usage
Some operations require different implementations:
def _solve_svd(self, X, y):
    xp, is_array_api_compliant = get_namespace(X)
    if is_array_api_compliant:
        # Use array API linalg
        svd = xp.linalg.svd
    else:
        # Fall back to scipy for NumPy arrays
        svd = scipy.linalg.svd
    _, S, Vt = svd(X, full_matrices=False)

def fit(self, X, y):
    xp, _ = get_namespace(X)
    # Compute and store in the same namespace
    self.mean_ = xp.mean(X, axis=0)  # Will be a torch tensor if X is a torch tensor
    self.classes_ = xp.unique_values(y)
    return self
Converting Fitted Estimators
To move a fitted estimator's attributes to a different format:
from sklearn.utils._array_api import _estimator_with_converted_arrays

# Convert CUDA tensors to NumPy
cupy_to_ndarray = lambda array: array.get()
lda_numpy = _estimator_with_converted_arrays(lda_gpu, cupy_to_ndarray)
3. Scalar Multiplication

# Wrong - Python float may not broadcast correctly
X = X * 0.5

# Correct - explicit array creation
X = X * xp.asarray(0.5, dtype=X.dtype, device=device(X))
4. Boolean Indexing
Array API has limited boolean indexing support:
# This may not work in all namespaces
X_subset = X[mask]

# Use xp.take with integer indices instead
indices = xp.arange(X.shape[0], device=device(X))
X_subset = xp.take(X, indices[mask], axis=0)
5. Calling Methods That Don't Exist
Not all NumPy methods exist in array API:
# These don't exist in the array API:
# - np.nanmean, np.nanmin, np.nanmax
# - np.bincount
# - np.unique (use unique_values, unique_inverse, etc.)

# Use sklearn helpers instead:
from sklearn.utils._array_api import _nanmean, _bincount
Operations Not in Array API Standard
For operations not in the standard, use scikit-learn's helper functions from sklearn.utils._array_api (for example _nanmean, _bincount, _average), as in the sketch below.
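A brief sketch of using one of these helpers (the exact keyword signature of _nanmean may differ between versions):

from sklearn.utils._array_api import get_namespace, _nanmean

def column_means_ignoring_nan(X):
    # np.nanmean has no array API equivalent; the sklearn helper dispatches to
    # namespace operations (and to NumPy for NumPy inputs).
    xp, _ = get_namespace(X)
    return _nanmean(X, axis=0, xp=xp)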
Restricted Parameter Combinations
Some estimators only support array API for certain parameter combinations:
# PCA: array API only with certain solvers
def __sklearn_tags__(self):
    tags = super().__sklearn_tags__()
    solver = getattr(self, "_fit_svd_solver", self.svd_solver)
    tags.array_api_support = solver not in ["arpack", "randomized"] or (
        solver == "randomized" and self.power_iteration_normalizer == "QR"
    )
    return tags
Document these restrictions in the docstring and raise NotImplementedError for unsupported configurations:
def fit(self, X, y):
    xp, is_array_api = get_namespace(X)
    if is_array_api and self.init_params not in ["random", "random_from_data"]:
        raise NotImplementedError(
            f"init_params={self.init_params!r} is not supported with "
            "array_api_dispatch enabled."
        )
This document describes the patterns for adding Array API support to scikit-learn metrics and standalone functions. Metrics follow a different convention from estimators: y_pred leads rather than X.
The "y_pred Leads" Convention
For scoring functions:
The namespace and device are determined from y_pred
All other inputs (y_true, sample_weight, etc.) are converted to match y_pred
This enables metrics to work within pipelines where X moves between devices
Why y_pred Leads
Consider a pipeline where:
X starts on CPU (with string categorical data)
A transformer encodes and moves X to GPU
A GPU-based estimator makes predictions
The predictions (y_pred) will be on GPU, but y_true may still be on CPU. By following y_pred, the metric automatically handles this mismatch.
Implementation Pattern for Metrics
Step 1: Get Namespace from y_pred
from sklearn.utils._array_api import (
    get_namespace_and_device,
    move_to,
    _average,
    _is_numpy_namespace,
)

def my_metric(y_true, y_pred, *, sample_weight=None):
    # Determine namespace and device from y_pred
    xp, _, device_ = get_namespace_and_device(y_pred)
Step 2: Convert Other Inputs with move_to()
def my_metric(y_true, y_pred, *, sample_weight=None):
    xp, _, device_ = get_namespace_and_device(y_pred)
    # Convert y_true and sample_weight to match y_pred
    y_true, sample_weight = move_to(y_true, sample_weight, xp=xp, device=device_)
Multi-Value Outputs: Return Array in Same Namespace
When a function needs to return multiple values (e.g., per-class scores, a matrix), return a single array in the same namespace as the input:
def confusion_matrix(y_true, y_pred, ...):
    # ... computation ...
    # Returns a 2D array (n_classes, n_classes) in the input namespace
    return xp.asarray(cm, device=device_)

def per_class_scores(y_true, y_pred, ...):
    # ... computation ...
    # Returns a 1D array with one value per class, in the input namespace
    return scores  # shape (n_classes,), in xp namespace
The array contains all the values; callers index into it as needed.
A more complex example showing fallback to NumPy for efficiency:
def confusion_matrix(y_true, y_pred, *, labels=None, sample_weight=None, normalize=None):
    """Compute confusion matrix."""
    # Get namespace for output consistency
    xp, _, device_ = get_namespace_and_device(y_true, y_pred, labels, sample_weight)

    # Validate inputs
    y_true = check_array(y_true, dtype=None, ensure_2d=False, ...)
    y_pred = check_array(y_pred, dtype=None, ensure_2d=False, ...)

    # Convert to NumPy for efficient computation
    # (scipy.sparse.coo_matrix is more efficient for this)
    y_true = _convert_to_numpy(y_true, xp)
    y_pred = _convert_to_numpy(y_pred, xp)
    if sample_weight is not None:
        sample_weight = _convert_to_numpy(sample_weight, xp)

    # ... compute confusion matrix using NumPy/scipy ...

    # Convert the result back to the input namespace for consistency
    return xp.asarray(cm, device=device_)
Complete Example: balanced_accuracy_score
Shows handling namespace-specific behavior:
def balanced_accuracy_score(y_true, y_pred, *, sample_weight=None, adjusted=False):
    """Compute the balanced accuracy."""
    # Compute confusion matrix (already handles namespace)
    C = confusion_matrix(y_true, y_pred, sample_weight=sample_weight)
    xp, _, device_ = get_namespace_and_device(y_pred, y_true)

    # Handle array-api-strict quirks
    if _is_xp_namespace(xp, "array_api_strict"):
        # array_api_strict only supports floating point for __truediv__
        C = xp.astype(C, _max_precision_float_dtype(xp, device=device_), copy=False)

    # Handle division warnings for NumPy
    context_manager = (
        np.errstate(divide="ignore", invalid="ignore")
        if _is_numpy_namespace(xp)
        else nullcontext()
    )
    with context_manager:
        per_class = xp.linalg.diagonal(C) / xp.sum(C, axis=1)

    # Handle NaN for missing classes
    if xp.any(xp.isnan(per_class)):
        warnings.warn("y_pred contains classes not in y_true")
        per_class = per_class[~xp.isnan(per_class)]

    score = xp.mean(per_class)
    if adjusted:
        n_classes = per_class.shape[0]
        chance = 1 / n_classes
        score -= chance
        score /= 1 - chance
    return float(score)
Handling Helper Functions
When a metric calls internal helper functions, pass the namespace:
def _check_targets(y_true, y_pred, sample_weight=None):
    """Check targets are valid and consistent."""
    xp, _ = get_namespace(y_true, y_pred)
    # ... use xp throughout
    return y_type, y_true, y_pred, sample_weight

def my_metric(y_true, y_pred, *, sample_weight=None):
    xp, _, device_ = get_namespace_and_device(y_pred)
    y_true, sample_weight = move_to(y_true, sample_weight, xp=xp, device=device_)
    # The helper function will get the namespace from its inputs
    y_type, y_true, y_pred, sample_weight = _check_targets(
        y_true, y_pred, sample_weight
    )
Pairwise Metrics
For pairwise distance/kernel functions:
def euclidean_distances(X, Y=None, *, Y_norm_squared=None, squared=False, X_norm_squared=None):
    xp, _, device_ = get_namespace_and_device(X, Y)

    # Move Y to match X if needed
    if Y is not None:
        Y = move_to(Y, xp=xp, device=device_)

    # Compute with namespace operations
    # ...
    return distances  # Returns an array in the same namespace as X
Pattern 3: Metric Returning Multiple Values

def multi_value_metric(y_true, y_pred):
    xp, _, device_ = get_namespace_and_device(y_pred)
    y_true = move_to(y_true, xp=xp, device=device_)
    value1 = xp.mean(y_true == y_pred)
    value2 = xp.sum(y_true != y_pred)
    # Return arrays, not floats, for multiple values
    return value1, value2
Pattern 4: Metric with Fallback to NumPy
For efficiency, some metrics compute in NumPy then convert back:
def efficient_metric(y_true, y_pred):
    xp, _, device_ = get_namespace_and_device(y_pred)

    # Convert to NumPy for efficient computation
    y_true_np = _convert_to_numpy(y_true, xp)
    y_pred_np = _convert_to_numpy(y_pred, xp)

    # Use efficient NumPy/SciPy operations
    result_np = scipy_function(y_true_np, y_pred_np)

    # Convert back to the original namespace
    return xp.asarray(result_np, device=device_)
Testing Metrics with Array API
See the testing patterns document for details. Key points:
Use yield_namespace_device_dtype_combinations() for parametrization
Compare results between NumPy and array API implementations
Handle floating-point tolerance differences for different dtypes
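A sketch of the usual test parametrization (the import location of _array_api_for_tests is an assumption; adjust to your scikit-learn version):

import pytest
from sklearn.utils._array_api import yield_namespace_device_dtype_combinations
from sklearn.utils._array_api import _array_api_for_tests  # assumed location

@pytest.mark.parametrize(
    "array_namespace, device, dtype_name",
    yield_namespace_device_dtype_combinations(),
)
def test_my_metric_array_api(array_namespace, device, dtype_name):
    # Skips automatically when the array library or device is unavailable.
    xp = _array_api_for_tests(array_namespace, device)
    ...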
This document describes how to write tests for array API support in scikit-learn. Testing ensures that estimators and functions work correctly across different array libraries and devices.
def test_mixed_inputs(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)

    # X on GPU, y on CPU
    X_xp = xp.asarray(np.random.randn(100, 10).astype(dtype_name), device=device)
    y_np = np.array([0, 1] * 50)

    with config_context(array_api_dispatch=True):
        clf = LinearDiscriminantAnalysis()
        clf.fit(X_xp, y_np)  # y should be converted to match X

        # Predictions should be on the same device as X. `device` here is the
        # test parameter (a string), so import the helper under an alias:
        # from sklearn.utils._array_api import device as array_device
        predictions = clf.predict(X_xp)
        assert array_device(predictions) == array_device(X_xp)
Different dtypes require different tolerances. Use _atol_for_type() for consistent,
dtype-appropriate absolute tolerances across the codebase.
The _atol_for_type() Utility (Preferred)
from sklearn.utils._array_api import _atol_for_type

def test_numerical_accuracy(array_namespace, device, dtype_name):
    xp = _array_api_for_tests(array_namespace, device)

    # Use _atol_for_type for absolute tolerance
    # Formula: numpy.finfo(dtype).eps * 1000
    # - float64: ~2.22e-13
    # - float32: ~1.19e-4
    atol = _atol_for_type(dtype_name)

    # Common rtol pattern
    rtol = 1e-5 if dtype_name == "float64" else 1e-4

    assert_allclose(result_np, result_xp_np, rtol=rtol, atol=atol)
Why atol Matters for Comparisons Against Zero
When using assert_allclose(actual, desired, rtol=...), the check is:
|actual - desired| <= atol + rtol * |desired|
Problem: When desired=0 and default atol=0:
The check becomes |actual| <= 0
Any tiny floating-point noise (e.g., 1e-15) causes failure
The relative difference becomes inf (division by zero)
Solution: Always use atol=_atol_for_type(dtype_name) when comparing values that
may include exact zeros, such as gradient arrays or covariance matrices.
Recommended Pattern for Gradient Tests
from sklearn.utils._array_api import _atol_for_type

def test_kernel_gradient_array_api(array_namespace, device, dtype_name):
    # ... compute K_grad_np and K_grad_xp_np ...
    rtol = 1e-5 if dtype_name == "float64" else 1e-4
    assert_allclose(K_xp_np, K_np, rtol=rtol)

    # Use atol for gradients since they may contain exact zeros
    assert_allclose(K_grad_xp_np, K_grad_np, rtol=rtol, atol=_atol_for_type(dtype_name))
This document provides a reference for the utility functions in sklearn/utils/_array_api.py. These functions provide the foundation for array API support in scikit-learn. Note that sklearn/utils/_array_api.py may contain functions not shown here; always double-check, as the code is the source of truth.
Purpose of sklearn/utils/_array_api.py
The purpose of this file is to provide a home for helper functions. This includes general utilities like get_namespace, a customized version of array_api_compat's implementation. It also includes functions like nanmin that exist in one or more array libraries, but are not part of the array API standard.
Adding to sklearn/utils/_array_api.py
When you encounter a function that is present in NumPy but not part of the array API standard, it is time to create an implementation of that function in sklearn/utils/_array_api.py. Examples of this are listed in the "Array Operations Not in Standard" section.
To implement a new function, the ideal path is to take the NumPy implementation and convert it to building blocks and idioms that are part of the array API. It is fine to implement only a subset of the functionality that the NumPy version provides; be pragmatic. The function signature should mirror NumPy's, but it is acceptable to omit some arguments or to support only a subset of the possible values for an argument.
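As an illustration of that conversion, a stripped-down nanmin replacement might look like this (a sketch, not scikit-learn's actual implementation; only the axis argument is supported):

from sklearn.utils._array_api import get_namespace, device

def _nanmin_sketch(X, axis=None, xp=None):
    # Mirror np.nanmin: replace NaNs with +inf so they never win the minimum.
    xp, _ = get_namespace(X, xp=xp)
    mask = xp.isnan(X)
    fill = xp.asarray(xp.inf, dtype=X.dtype, device=device(X))
    return xp.min(xp.where(mask, fill, X), axis=axis)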
get_namespace()
Parameters:
remove_none: Whether to ignore None values (default: True)
remove_types: Types to ignore (default: (str, list, tuple))
xp: Pre-computed namespace to skip inspection
Returns:
namespace: The array namespace module (e.g., torch, cupy)
is_array_api_compliant: True if dispatch is enabled, False otherwise
Behavior:
When array_api_dispatch=False: Always returns (numpy_compat, False)
When array_api_dispatch=True: Returns the actual namespace of input arrays
Sparse arrays are filtered out
Pandas DataFrames/Series are filtered out
Example:
import torch
from sklearn import config_context
from sklearn.utils._array_api import get_namespace

X = torch.randn(100, 10)
with config_context(array_api_dispatch=True):
    xp, is_compliant = get_namespace(X)
    # xp is the torch namespace
    # is_compliant is True
move_to()
Returns:
A single array if given one input, a tuple if given multiple inputs
Notes:
Uses DLPack protocol when available for zero-copy transfer
Falls back to NumPy intermediate when DLPack fails
Sparse arrays are passed through unchanged (only for NumPy target)
None values are passed through unchanged
Example:
from sklearn.utils._array_api import move_to, get_namespace_and_device

X_cuda = torch.randn(100, 10, device="cuda")
y_numpy = np.array([0, 1, 1, 0, ...])

xp, _, dev = get_namespace_and_device(X_cuda)
y_cuda = move_to(y_numpy, xp=xp, device=dev)
# y_cuda is now a torch.Tensor on CUDA
Use the device() Utility
Read an array's device through the device() helper from sklearn.utils._array_api rather than ad-hoc attribute access.
Rationale: The device() function in _array_api.py handles edge cases consistently across all array backends, including NumPy arrays, which don't have a .device attribute. Using ad-hoc patterns like getattr(X, "device", None) can lead to inconsistent behavior and misses the benefit of centralized maintenance.
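A minimal illustration (a sketch):
❌ Bad: ad-hoc attribute access
dev = getattr(X, "device", None)  # None for NumPy arrays; inconsistent across backends
✅ Good: the centralized helper
from sklearn.utils._array_api import device
dev = device(X)  # Consistent for NumPy, PyTorch, CuPy, and array-api-strict inputs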
Avoid NumPy Scalar Types in Array Arithmetic
When multiplying arrays by scalars that come from NumPy or SciPy functions (e.g., scipy.special.gamma), convert them to Python float first.
❌ Bad: Using numpy scalar directly
from scipy.special import gamma

coef = (2 ** (1.0 - nu)) / gamma(nu)  # numpy.float64
K = coef * array  # May corrupt device info for array_api_strict
✅ Good: Convert to Python float first
from scipy.special import gamma

coef = float((2 ** (1.0 - nu)) / gamma(nu))  # Python float
K = coef * array  # Preserves device info
Rationale: When a numpy.float64 is multiplied with an array_api_strict array, the result may have a corrupted device attribute (e.g., 'cpu' as a string instead of a proper Device object). Python float scalars broadcast correctly and preserve device information.
Import Placement
Place all imports at the top of the file, not inline within functions.
❌ Bad: Inline imports of standard dependencies
def some_function(X, xp):
    from scipy.spatial.distance import cdist  # Should be at top of file
    return cdist(X, X)
Exception: Optional array library dependencies (torch, cupy, array_api_strict) should be imported inline, only when we know the user has them installed. This is typically determined by checking if the input array is from that namespace.
✅ Good: Conditional inline import of optional dependencies
from scipy.spatial.distance import cdist

def _cdist(X, Y, metric="euclidean", xp=None):
    xp, _, device_ = get_namespace_and_device(X, Y, xp=xp)
    if _is_numpy_namespace(xp):
        return xp.asarray(cdist(X, Y, metric=metric))
    if _is_xp_namespace(xp, "torch"):
        # No need to import torch as it is available as `xp` already
        return xp.cdist(X, Y, p=2)
    if _is_xp_namespace(xp, "cupy"):
        from cupyx.scipy.spatial.distance import cdist  # Only when input is cupy
        return cdist(X, Y, metric=metric)
Rationale: Optional dependencies like torch and cupy are not required for scikit-learn to function. Importing them at module load time would cause ImportError for users who don't have them installed. By importing conditionally inside functions, we only attempt the import when we know the dependency is available (because the user passed an array from that library).
Use Functional Form for Matrix Transpose
Use xp.matrix_transpose(L) instead of the .mT attribute for matrix transposition.
❌ Bad: Using .mT attribute
# CuPy arrays don't have the .mT attribute
L_transposed = L.mT
✅ Good: Use functional form
L_transposed = xp.matrix_transpose(L)
Rationale: The .mT attribute (matrix transpose) was added in the Array API 2022.12 standard, but CuPy hasn't fully adopted it. CuPy arrays raise AttributeError: 'ndarray' object has no attribute 'mT'. The functional form xp.matrix_transpose() is provided by array_api_compat and works consistently across all backends (NumPy, PyTorch, CuPy, etc.).
Note: For 2D arrays, .mT and .T are equivalent, but .mT specifically transposes only the last two dimensions (important for batched operations). The functional xp.matrix_transpose() has the same semantics as .mT.
Preserve Input Dtype on Limited Devices
Some devices don't support float64 (double precision). Don't hardcode float64 conversions.
❌ Bad: Hardcoded float64 conversion
def process_array(X, xp):
    # This will fail on MPS devices
    X = xp.astype(X, xp.float64)
    return X
✅ Good: Preserve input dtype
def process_array(X, target_array, xp):
    # Match an existing array's dtype instead of hardcoding float64,
    # staying compatible with devices that lack float64
    if X.dtype != target_array.dtype:
        X = xp.astype(X, target_array.dtype)
    return X
Rationale: Apple's MPS (Metal Performance Shaders) device does not support float64. Attempting to convert to float64 on MPS raises TypeError: Cannot convert a MPS Tensor to float64 dtype. By preserving the input array's dtype, code remains compatible with all devices.
Testing
Running Tests with MPS Device
On Apple Silicon Macs, you can run tests using the MPS (Metal Performance Shaders) device for GPU acceleration:
PYTORCH_ENABLE_MPS_FALLBACK=1 SCIPY_ARRAY_API=1 pytest -v -k "array_api and torch"
Environment variables:
PYTORCH_ENABLE_MPS_FALLBACK=1: Allows PyTorch to fall back to CPU for operations not supported on MPS (prevents hard errors)
SCIPY_ARRAY_API=1: Enables Array API support in scipy
Note: Without PYTORCH_ENABLE_MPS_FALLBACK=1, tests may fail with errors about unsupported operations. The fallback allows tests to complete while still exercising MPS for supported operations.
Devices Without float64 Support
Some GPU devices do not support double precision (float64) floating point numbers:
| Device | float64 Support | Notes |
|---|---|---|
| CPU | ✅ Yes | Full support |
| CUDA (NVIDIA) | ✅ Yes | Full support |
| MPS (Apple Silicon) | ❌ No | Only float32 supported |
| XPU (Intel) | ✅ Yes | Full support |
When writing Array API compatible code:
Never hardcode float64: Use X.dtype to preserve the input array's dtype
Use input dtype for new arrays: When creating new arrays, use dtype=X.dtype and device=array_device(X)
Test with float32 configurations: The test suite includes float32 parameter combinations to catch float64 assumptions
Example of creating compatible arrays:
xp, _, device_ = get_namespace_and_device(X)

# Create an array matching the input dtype and device
ones = xp.ones(shape, dtype=X.dtype, device=device_)

# Convert a scalar to an array on the correct device with the correct dtype
value = xp.asarray(1.0, dtype=X.dtype, device=device_)
Scipy Dtype Preservation
Several scipy functions (scipy.linalg.cholesky, scipy.linalg.cho_solve, scipy.spatial.distance.pdist, etc.) upcast float32 inputs to float64 for their computations. This causes dtype mismatches when comparing numpy results with array API results (which preserve the input dtype).
The utility functions in _array_api.py (_cholesky, _cho_solve, _cdist, _pdist, _squareform) handle this by casting the result back to the input dtype when the input was floating point:
def _cholesky(covariance, xp):
    if _is_numpy_namespace(xp):
        # scipy may upcast float32 to float64; cast back to preserve the input dtype
        result = scipy.linalg.cholesky(covariance, lower=True)
        return result.astype(covariance.dtype, copy=False)
    else:
        return xp.linalg.cholesky(covariance)
Note: The dtype cast-back only happens for floating point inputs. Integer inputs (e.g., from structured data with string features) are not cast back to avoid precision loss.
This document provides a structured workflow for converting scikit-learn components to support the Array API. It serves as a checklist to ensure important steps aren't missed and work proceeds in the right order.
Add new utility functions to sklearn/utils/_array_api.py as needed
Run tests frequently during development
Handle device placement correctly (use device() utility, not getattr)
Key implementation patterns:
from sklearn.utils._array_api import (
    get_namespace_and_device,
    move_to,
    _convert_to_numpy,
)

def fit(self, X, y):
    xp, is_array_api, device_ = get_namespace_and_device(X)

    # Move y to match X's namespace/device
    y = move_to(y, xp=xp, device=device_)

    # Use xp instead of np throughout
    result = xp.sum(X, axis=0)

    # Store fitted attributes in the same namespace
    self.mean_ = xp.mean(X, axis=0)
Common pitfalls to avoid:
Don't use getattr(X, "device", None) - use the device() utility
Don't multiply arrays by NumPy scalars - convert to Python float first
Don't hardcode float64 - preserve input dtype for MPS compatibility
Phase 4: Cleanup Pass
After implementation is working:
Review for code quality and DRY principles
Ensure all imports are at the top of the file (except optional array library imports)
Remove any debug code or print statements
Check that code follows existing conventions in the codebase
Memory Benchmarking Guidelines for Array API Backends
Overview
When comparing memory usage across different array backends (NumPy, PyTorch CPU, PyTorch CUDA, CuPy, etc.), it's critical to use consistent measurement methodologies. This document describes common pitfalls and recommended approaches.
The Problem: Inconsistent Memory Measurement
During GPR performance benchmarking (January 2026), we observed what appeared to be a 26-39x memory difference between NumPy and PyTorch CUDA for predict() operations. This turned out to be a measurement artifact, not a real difference.
What Happened
The benchmark used:
NumPy: tracemalloc to measure memory
PyTorch CUDA: torch.cuda.max_memory_allocated()
These measure fundamentally different things:
| Tool | What It Measures |
|---|---|
| tracemalloc | NEW Python object allocations during the measured block |
| torch.cuda.max_memory_allocated() | TOTAL GPU memory allocated, including pre-existing tensors |
Example: GPR predict() at n=10,000
NumPy reported: 30 MB (only new allocations during predict)
PyTorch reported: 803 MB (total GPU memory including stored L_ matrix)
Actual comparison:
- NumPy stored model: 801 MB (L_, X_train_, alpha_ in RAM - not counted!)
- PyTorch stored model: 801 MB (same tensors on GPU - counted!)
- Additional for predict: ~32 MB (both backends, nearly identical)
The L_ matrix (Cholesky factor, 800 MB) was already allocated before predict() started. NumPy's tracemalloc didn't count it because it wasn't a new allocation. PyTorch's memory API counted it because it's part of total GPU memory.
Recommended Approaches
Option 1: Measure Delta Memory (Recommended for Operation Comparison)
Measure memory before and after an operation to get the delta:
# NumPy
import tracemalloc

tracemalloc.start()
# ... operation ...
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
delta = peak  # This is new allocations only

# PyTorch CUDA
import torch

torch.cuda.synchronize()
before = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()
# ... operation ...
torch.cuda.synchronize()
peak = torch.cuda.max_memory_allocated()
delta = peak - before  # Delta from baseline
Option 2: Measure Total Stored State (Recommended for Model Size)
To compare model storage requirements, measure the actual tensors/arrays stored by the fitted model, as in the sketch below.
Note: the tensors/arrays stored depend on the model. Don't just copy the attribute names from the sketch.
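A sketch of such a measurement, using the GPR attribute names from the example above (L_, X_train_, alpha_) purely for illustration:

def stored_bytes(model, attrs=("L_", "X_train_", "alpha_")):
    # Sum the in-memory size of the fitted arrays/tensors themselves,
    # regardless of which backend they live in.
    total = 0
    for name in attrs:
        arr = getattr(model, name, None)
        if arr is None:
            continue
        if hasattr(arr, "nbytes"):  # NumPy, CuPy (and recent PyTorch)
            total += int(arr.nbytes)
        else:  # PyTorch tensors without .nbytes
            total += arr.element_size() * arr.nelement()
    return total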
Option 3: Fresh Process Measurement (Most Accurate but Slow)
For the most accurate comparison, measure in a fresh process for each backend:
def measure_in_subprocess(backend, operation):
    """Run measurement in an isolated subprocess."""
    # Each measurement starts with a clean memory state
    ...
PyTorch-Specific Considerations
CUDA Memory Caching
PyTorch's CUDA allocator caches memory blocks for reuse. This means:
torch.cuda.memory_allocated() shows actually used memory
torch.cuda.memory_reserved() shows cached memory (may be higher)
Use torch.cuda.empty_cache() before measurements to release cached blocks
import torch
import gc

gc.collect()
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()
# Now measure...
inference_mode() Does NOT Reduce Memory
Testing showed that torch.inference_mode() has no effect on memory usage for scikit-learn operations. This is because scikit-learn doesn't use PyTorch's autograd - tensors are created without requires_grad=True.
# These use identical memory:
result = model.predict(X)

with torch.inference_mode():
    result = model.predict(X)
Memory Breakdown for GPR (Reference)
For a fitted GaussianProcessRegressor with n training samples, the dominant stored array is the (n, n) Cholesky factor L_ (roughly 8·n² bytes in float64, which matches the ~800 MB observed at n=10,000 above), alongside X_train_ and alpha_.