Out-of-core QR decompositions using dask
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Exploring out-of-core QR decompositions via dask.
conda create -n dask python=3.10
conda activate dask
pip install numpy
pip install h5py
pip install dask
pip install psutil # to monitor RAM
pip install matplotlib # optional
python dask_qr_test.py

The snippet below runs, but it exhausts the available RAM, so it
is not a good test of out-of-core behaviour. Artificially limiting
the available RAM proved to be challenging:

* [client-free (the one shown below)]: This version does not seem
  to respond to limiting memory via dask options. Using the OS to
  artificially limit memory (e.g. via ulimit, cgroups, or
  resource.setrlimit; see the ``limit_address_space`` sketch in
  the HELPERS section below) results in dask trying to go beyond
  the limit and crashing with an OOM error. Simply put, this mode
  seems to take as much memory as it can, and crashes otherwise.
  That could be fine, if we didn't aim to limit memory usage
  precisely to verify out-of-core behaviour.
* [distributed client]: If we add a cluster-client pair like the
  one below, it comes with explicit memory control, and running
  e.g. ``r.compute()`` works like a charm (see the
  ``qr_r_only_distributed`` sketch in the HELPERS section below).
  The problem is that we want to write Q in-place to the HDF5
  file, and distributed clients always attempt to serialize it,
  even if we specify that we want just 1 thread. This results in
  an error, since plain HDF5 file handles cannot be serialized or
  written concurrently.

Thus, the two options above are mutually exclusive: there seems to
be no way of imposing a hard memory limit *and* writing to HDF5
files in-place using dask. Ideally, we would have something like a
single-threaded cluster/client pattern that does not require
multiprocessing, thus allowing the use of non-serializable
objects::

  cluster = LocalCluster(
      n_workers=1,
      threads_per_worker=1,
      processes=False,
      memory_limit="2GB",
      asynchronous=False,
  )
  client = Client(cluster)
"""
from time import time
import os
from tempfile import TemporaryDirectory
import numpy as np
import h5py
import dask
import dask.array as da
import matplotlib.pyplot as plt
import psutil
# ##############################################################################
# # HELPERS
# ##############################################################################
def qr_inplace_h5(arr, chunks="auto"):
    """Computes the reduced QR decomposition of ``arr`` (a NumPy array
    or HDF5 dataset) via dask, overwriting ``arr`` with Q in-place and
    returning R as a NumPy array."""
    ooc = da.from_array(arr, chunks=chunks)
    q, r = da.linalg.qr(ooc)
    # build the Q->arr store task lazily, then compute it together with R
    store_q = da.store(q, arr, compute=False)
    _, r_vals = dask.compute(store_q, r)
    return r_vals
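

# The following is a hedged, minimal sketch of the [distributed client]
# variant discussed in the module docstring: a single-worker LocalCluster
# with an explicit memory limit computes only R (Q is discarded, since
# writing it back in-place would require serializing the HDF5 handle).
# It assumes ``dask.distributed`` is installed (e.g. ``pip install
# distributed``) and is not called anywhere by default.
def qr_r_only_distributed(arr, chunks="auto", memory_limit="2GB"):
    """Returns R from the reduced QR of ``arr``, computed under an
    explicit memory limit via a local distributed cluster."""
    from dask.distributed import Client, LocalCluster
    cluster = LocalCluster(
        n_workers=1,
        threads_per_worker=1,
        processes=False,
        memory_limit=memory_limit,
        asynchronous=False,
    )
    client = Client(cluster)
    try:
        ooc = da.from_array(arr, chunks=chunks)
        _, r = da.linalg.qr(ooc)
        return r.compute()
    finally:
        client.close()
        cluster.close()
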
def dist(ori, appr, relative=False):
    """Frobenius distance between ``ori`` and ``appr``, optionally
    relative to the norm of ``ori``."""
    result = np.linalg.norm(ori - appr)
    if relative:
        result /= np.linalg.norm(ori)
    return result
def plot_covmat(tall_arr, log=False):
    """Plots (and returns) the Gram matrix of a tall array; for an
    orthonormal Q it should look like the identity."""
    plt.clf()
    G = tall_arr.conj().T @ tall_arr
    plt.imshow(np.log(G) if log else G)
    plt.show()
    return G
def print_mem(process, prefix=""):
    """Prints the resident set size (RSS) of the given psutil process in GB."""
    rss = process.memory_info().rss / 1e9
    print(f"{prefix}RSS = {rss:.2f} GB")
# ##############################################################################
# # GLOBALS
# ##############################################################################
TEST_MEMORY = True
#
RUNTIME_SHAPE = (20000, 300)
RUNTIME_CHUNKS = "auto"
#
MEMORY_SHAPE = (100_000, 500) # 50M entries
MEMORY_CHUNKS = (10_000, MEMORY_SHAPE[1])
#
H5NAME = "A.h5"
DSNAME = "data"
# ##############################################################################
# # MAIN ROUTINE
# ##############################################################################
def runtimes_and_correctness():
    """Compares runtime and reconstruction error of plain NumPy QR
    against the dask in-place QR, run on a NumPy array and on an
    HDF5 dataset."""
    # create dask arrays both from NP (in RAM) and HDF5 (on disk)
    A = np.random.randn(*RUNTIME_SHAPE)
    A_orig = A.copy()
    with TemporaryDirectory() as tmpdir:
        # save numpy array into temporary HDF5 file
        h5path = os.path.join(tmpdir, H5NAME)
        da.from_array(A, chunks=RUNTIME_CHUNKS).to_hdf5(
            h5path,
            DSNAME,
            chunks=True if RUNTIME_CHUNKS == "auto" else RUNTIME_CHUNKS,
        )
        # load temp file, run QR on both numpy and HDF5,
        # check for correctness and runtime
        with h5py.File(h5path, mode="r+") as f:
            # pure numpy
            t0 = time()
            Q1, R1 = np.linalg.qr(A)
            elapsed1 = time() - t0
            # dask on numpy
            t0 = time()
            R2 = qr_inplace_h5(A, chunks=RUNTIME_CHUNKS)
            elapsed2 = time() - t0
            Q2 = A
            # dask on HDF5
            t0 = time()
            R3 = qr_inplace_h5(f[DSNAME], chunks=RUNTIME_CHUNKS)
            elapsed3 = time() - t0
            Q3 = f[DSNAME][:]
    # all should take about the same time if chunks equals the full shape
    print(f"[Pure NP] s={elapsed1}, err={dist(A_orig, Q1 @ R1, True)}")
    print(f"[Dask NP] s={elapsed2}, err={dist(A_orig, Q2 @ R2, True)}")
    print(f"[Dask H5] s={elapsed3}, err={dist(A_orig, Q3 @ R3, True)}")
    #
    breakpoint()  # drop into a debugger for interactive inspection
    # plot_covmat(A_orig)
    # plot_covmat(Q1)
    # plot_covmat(Q2)
    # plot_covmat(Q3)
def ooc_memory():
    """Streams a random matrix into an HDF5 file chunk by chunk, runs the
    in-place QR on it and reports RAM usage along the way. If resident
    memory stays well below the file size, the QR ran out of core."""
    process = psutil.Process(os.getpid())
    print_mem(process, "Initial memory: ")
    with TemporaryDirectory() as tmpdir:
        h5path = os.path.join(tmpdir, H5NAME)
        with h5py.File(h5path, mode="w") as f:
            h5 = f.create_dataset(
                DSNAME,
                shape=MEMORY_SHAPE,
                chunks=MEMORY_CHUNKS,
                dtype="float64",
            )
            print_mem(process, "RAM After creating h5: ")
            #
            A = da.random.standard_normal(
                size=MEMORY_SHAPE, chunks=MEMORY_CHUNKS
            )
            da.store(A, h5)  # this streams chunks into the file
            print_mem(process, "RAM After writing h5: ")
            #
            h5_gb = os.path.getsize(h5path) / 1e9
            print(f"h5 size on disk before QR: {h5_gb} GB")
            #
            R3 = qr_inplace_h5(f[DSNAME], chunks=MEMORY_CHUNKS)
            print_mem(process, "RAM After QR: ")
            h5_gb = os.path.getsize(h5path) / 1e9  # recompute after QR
            print(f"h5 size on disk after QR: {h5_gb} GB")
            #
            print(f"If RAM never surpassed {h5_gb} GB, the QR was not in-core")


if __name__ == "__main__":
    if TEST_MEMORY:
        ooc_memory()
    else:
        runtimes_and_correctness()