spectralbrain.backends.gpu#

GPU compute backends (CuPy, JAX, PyTorch), GPU Bayesian samplers (NumPyro, BlackJAX), and VRAM management.

This module provides:

  1. CupyBackend: drop-in GPU replacement for NumpyBackend using CuPy (cupy-cuda13x).

  2. JaxBackend: GPU backend with jit and vmap for batch-subject spectral descriptor computation.

  3. TorchBackend: GPU backend using PyTorch, with the same interface as NumpyBackend.

  4. NumPyroSampler and BlackjaxSampler: GPU-accelerated Bayesian MCMC via JAX.

  5. VRAM management: monitoring, cache clearing, defragmentation, garbage collection, and a memory-guarded context manager.

All dependencies are lazy-imported. If CuPy or JAX is missing, the module still imports successfully — only instantiation of the backends raises ImportError.
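The lazy-import behaviour can be illustrated with a minimal sketch. This is not the module's actual implementation: the class names LazyBackend, MissingDepBackend, and StdlibBackend are hypothetical, and a standard-library module stands in for an installed GPU dependency.

```python
import importlib


class LazyBackend:
    """Hypothetical backend whose heavy dependency is imported on first use."""

    # Name of the optional dependency; "cupy" is only an illustration.
    module_name = "cupy"

    def __init__(self):
        try:
            # The import happens at instantiation, not when this file is imported.
            self.xp = importlib.import_module(self.module_name)
        except ImportError as exc:
            raise ImportError(
                f"{type(self).__name__} requires the optional package "
                f"'{self.module_name}'"
            ) from exc


class MissingDepBackend(LazyBackend):
    # A module name assumed not to exist, to demonstrate the failure path.
    module_name = "surely_not_an_installed_module_xyz"


class StdlibBackend(LazyBackend):
    # A standard-library module stands in for an installed dependency.
    module_name = "math"
```

Defining the classes never fails; only `MissingDepBackend()` raises ImportError, while `StdlibBackend()` succeeds.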

Functions

get_gpu_backend([name])

Factory for GPU compute backends.

get_gpu_bayesian_sampler([backend])

Factory for GPU Bayesian samplers.

vram_clear()

Clear CUDA memory caches across all known GPU frameworks.

vram_defrag([device_id])

Attempt CUDA memory defragmentation.

vram_gc([device_id])

Full GPU garbage collection: Python GC + VRAM clear + defrag.

vram_guard([min_free_gb, device_id, ...])

Context manager for VRAM-safe GPU operations.

vram_monitor([device_id, label])

Log current VRAM usage (one-shot, for debugging).

vram_status([device_id])

Query current VRAM usage.

Classes

BlackjaxSampler([num_warmup, num_samples, ...])

GPU-accelerated Bayesian NUTS via BlackJAX.

CupyBackend([device_id])

GPU compute backend using CuPy.

JaxBackend([device])

GPU backend using JAX with jit and vmap.

NumPyroSampler([num_warmup, num_samples, ...])

GPU-accelerated Bayesian MCMC using NumPyro + JAX.

TorchBackend([device])

GPU compute backend using PyTorch.

VRAMInfo(device_name, device_id, total_gb, ...)

Snapshot of GPU VRAM usage.

class spectralbrain.backends.gpu.BlackjaxSampler(num_warmup=1000, num_samples=2000, num_chains=4, seed=42)[source]#

Bases: object

GPU-accelerated Bayesian NUTS via BlackJAX.

BlackJAX is a low-level sampler that operates on a log-density function rather than a model object, which makes it composable and fast under JAX jit/vmap on the GPU. This wrapper runs the standard window-adaptation → NUTS pipeline and (optionally) vectorises independent chains with jax.vmap().

Parameters:
  • num_warmup (int) – Window-adaptation (tuning) steps.

  • num_samples (int) – Posterior draws per chain.

  • num_chains (int) – Independent chains, run in parallel via vmap.

  • seed (int) – PRNG seed.

Examples

>>> import jax.numpy as jnp
>>> def logdensity(theta):
...     # standard-normal target
...     return -0.5 * jnp.sum(theta ** 2)
>>> sampler = BlackjaxSampler(num_warmup=500, num_samples=1000)
>>> samples = sampler.sample(logdensity, initial_position=jnp.zeros(3))
>>> samples.shape  # (num_chains, num_samples, 3) with the default num_chains=4
sample(logdensity_fn, initial_position, *, rng_key=None)[source]#

Run NUTS on a log-density function.

Parameters:
  • logdensity_fn (callable) – Maps a parameter pytree to a scalar log-density (unnormalised log-posterior). Must be JAX-traceable.

  • initial_position (pytree) – Starting position for a single chain. For num_chains > 1 it is broadcast across chains.

  • rng_key (jax.Array, optional) – PRNG key. Defaults to jax.random.PRNGKey(self.seed).

Returns:

pytree – Posterior draws. For a single chain each leaf has shape (num_samples, *param_shape); for multiple chains (num_chains, num_samples, *param_shape).

Return type:

Any

to_arviz(samples, *, var_names=None)[source]#

Convert posterior draws to ArviZ InferenceData.

Parameters:
  • samples (pytree) – Output of sample(). A dict maps variable names to draws; an array is wrapped under names from var_names (or "x").

  • var_names (sequence of str, optional) – Names for array-valued samples.

Returns:

arviz.InferenceData

Return type:

Any

name: str = 'blackjax'#
class spectralbrain.backends.gpu.CupyBackend(device_id=0)[source]#

Bases: object

GPU compute backend using CuPy.

Mirrors the NumpyBackend interface. Arrays live on the GPU; to_numpy() copies back to host.

Parameters:

device_id (int) – CUDA device index.

Examples

>>> be = CupyBackend(device_id=0)
>>> evals, evecs = be.eigsh(L, M, k=100)
>>> type(evals)  # numpy.ndarray: eigsh returns host arrays
argsort(x, axis=-1)[source]#

Indirect sort indices on GPU.

Return type:

Any

array(data, dtype=<class 'numpy.float64'>)[source]#

Create a CuPy array on GPU.

Return type:

Any

clip(x, a_min, a_max)[source]#

Element-wise clip on GPU.

Return type:

Any

concatenate(arrays, axis=0)[source]#

Concatenate CuPy arrays.

Return type:

Any

eigsh(L, M=None, k=100, *, sigma=-0.01, which='LM', tol=0.0, maxiter=None, dense_max=20000)[source]#

Smallest-k generalised eigenpairs L v = λ M v on the GPU.

CuPy’s sparse eigsh supports neither the generalised problem (no M) nor shift-invert (no sigma), and recovering the smallest Laplacian eigenvalues by plain Lanczos is unreliable. Because the FEM mass matrix is diagonal (lumped barycentric), the generalised problem standardises exactly to a symmetric one:

à = D^{-1/2} L D^{-1/2}, D = diag(M), ψ = D^{1/2} v

We solve the dense symmetric eigenproblem Ã ψ = λ ψ on the GPU with cupy.linalg.eigh (robust, no ARPACK convergence issues), keep the k smallest eigenpairs, and recover the M-orthonormal eigenvectors v = D^{-1/2} ψ. Validated against SciPy shift-invert and the analytic sphere spectrum. Meshes with N > dense_max fall back to CPU sparse shift-invert to avoid densification OOM. sigma, which, tol, and maxiter are honoured only on the fallback path; the signature mirrors NumpyBackend.eigsh(). Returns host (NumPy) arrays.

Return type:

tuple[ndarray[tuple[Any, …], dtype[floating]], ndarray[tuple[Any, …], dtype[floating]]]
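The diagonal-mass standardisation described above can be sketched on the CPU with NumPy. This is an illustrative version under stated assumptions, not the backend's code: the helper name eigsh_lumped is hypothetical, L is taken as a small dense array, and the mass matrix is passed as a vector of lumped masses.

```python
import numpy as np


def eigsh_lumped(L, m, k):
    """Smallest-k eigenpairs of L v = lam * diag(m) v via symmetric standardisation.

    L : (N, N) dense symmetric array; m : (N,) positive lumped masses.
    """
    d_inv_sqrt = 1.0 / np.sqrt(m)
    # A_tilde = D^{-1/2} L D^{-1/2}, formed by row and column scaling.
    A = d_inv_sqrt[:, None] * L * d_inv_sqrt[None, :]
    # eigh returns eigenvalues in ascending order, so the first k are the smallest.
    lam, psi = np.linalg.eigh(A)
    lam, psi = lam[:k], psi[:, :k]
    # Map back: v = D^{-1/2} psi, which makes the columns M-orthonormal.
    v = d_inv_sqrt[:, None] * psi
    return lam, v


# Toy problem: path-graph Laplacian on 6 nodes with arbitrary positive masses.
N = 6
L = 2.0 * np.eye(N) - np.eye(N, k=1) - np.eye(N, k=-1)
L[0, 0] = L[-1, -1] = 1.0
m = np.linspace(0.5, 1.5, N)

lam, v = eigsh_lumped(L, m, k=3)
# Residual of the generalised problem, column by column: L v - lam * M v.
residual = L @ v - (m[:, None] * v) * lam[None, :]
# Gram matrix in the M inner product; should be the identity.
gram = v.T @ (m[:, None] * v)
```

Both `residual` and `gram - I` vanish to rounding error, which is the sense in which the standardisation is exact for a diagonal mass matrix.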

exp(x)[source]#

Element-wise exponential on GPU.

Parameters:

x (Any)

Return type:

Any

eye(n, dtype=<class 'numpy.float64'>)[source]#

Create a GPU identity matrix.

Return type:

Any

linspace(start, stop, num)[source]#

Linearly spaced values on GPU.

Return type:

Any

log(x)[source]#

Element-wise safe log on GPU.

Parameters:

x (Any)

Return type:

Any

logspace(start, stop, num)[source]#

Log-spaced values on GPU.

Return type:

Any

matmul(a, b)[source]#

GPU matrix multiply.

Return type:

Any

mean(x, axis=None)[source]#

Mean reduction on GPU.

Return type:

Any

norm(x, axis=None, ord=None)[source]#

Vector/matrix norm on GPU.

Return type:

Any

ones(shape, dtype=<class 'numpy.float64'>)[source]#

Create a ones-filled CuPy array.

Return type:

Any

sparse_matrix(data, row, col, shape, **kwargs)[source]#

Build a sparse matrix from COO triplets on GPU.

Return type:

Any

sqrt(x)[source]#

Element-wise safe sqrt on GPU.

Parameters:

x (Any)

Return type:

Any

stack(arrays, axis=0)[source]#

Stack CuPy arrays along a new axis.

Return type:

Any

sum(x, axis=None)[source]#

Sum reduction on GPU.

Return type:

Any

to_numpy(x)[source]#

Copy GPU array to host.

Parameters:

x (Any)

Return type:

ndarray

zeros(shape, dtype=<class 'numpy.float64'>)[source]#

Create a zero-filled CuPy array.

Return type:

Any

name: str = 'cupy'#
class spectralbrain.backends.gpu.JaxBackend(device='gpu')[source]#

Bases: object

GPU backend using JAX with jit and vmap.

The key advantages of JAX over CuPy for SpectralBrain are jax.vmap(), which vectorises descriptor computation across an entire cohort without explicit loops, and jax.jit(), which compiles hot paths for reuse.

Parameters:

device (str) – "gpu" or "cpu".

Examples

>>> be = JaxBackend()
>>> # Batch HKS for 228 subjects; t is shared across subjects, so it is not mapped:
>>> batched_hks = be.vmap(compute_hks, in_axes=(0, 0, None))(all_evals, all_evecs, t)
argsort(x, axis=-1)[source]#

Indirect sort indices via JAX.

Return type:

Any

array(data, dtype=<class 'numpy.float64'>)[source]#

Create a JAX array.

Return type:

Any

clip(x, a_min, a_max)[source]#

Element-wise clip via JAX.

Return type:

Any

concatenate(arrays, axis=0)[source]#

Concatenate JAX arrays.

Return type:

Any

eigsh(L, M=None, k=100, **kwargs)[source]#

Sparse eigensolver via JAX’s LOBPCG.

For the generalised problem L v = λ M v, falls back to SciPy ARPACK on the host and transfers the results, because JAX’s sparse eigensolver does not yet support generalised problems natively. Eigenvalues and eigenvectors are returned as NumPy arrays.

Parameters:
  • L (same as NumpyBackend.eigsh)

  • M (same as NumpyBackend.eigsh)

  • k (same as NumpyBackend.eigsh)

  • kwargs (Any)

Returns:

eigenvalues, eigenvectors – NumPy arrays.

Return type:

tuple[ndarray[tuple[Any, …], dtype[floating]], ndarray[tuple[Any, …], dtype[floating]]]

exp(x)[source]#

Element-wise exponential via JAX.

Parameters:

x (Any)

Return type:

Any

eye(n, dtype=<class 'numpy.float64'>)[source]#

Create a JAX identity matrix.

Return type:

Any

jit(func, **kwargs)[source]#

JIT-compile a function.

Parameters:
  • func (callable) – Pure function (no side effects).

  • kwargs (Any)

Returns:

callable – JIT-compiled version.

Return type:

Callable

linspace(start, stop, num)[source]#

Linearly spaced values via JAX.

Return type:

Any

log(x)[source]#

Element-wise safe log via JAX.

Parameters:

x (Any)

Return type:

Any

logspace(start, stop, num)[source]#

Log-spaced values via JAX.

Return type:

Any

matmul(a, b)[source]#

JAX matrix multiply.

Return type:

Any

mean(x, axis=None)[source]#

Mean reduction via JAX.

Return type:

Any

norm(x, axis=None, ord=None)[source]#

Vector/matrix norm via JAX.

Return type:

Any

ones(shape, dtype=<class 'numpy.float64'>)[source]#

Create a ones-filled JAX array.

Return type:

Any

sqrt(x)[source]#

Element-wise safe sqrt via JAX.

Parameters:

x (Any)

Return type:

Any

stack(arrays, axis=0)[source]#

Stack JAX arrays along a new axis.

Return type:

Any

sum(x, axis=None)[source]#

Sum reduction via JAX.

Return type:

Any

to_numpy(x)[source]#

Transfer JAX array to NumPy.

Parameters:

x (Any)

Return type:

ndarray

vmap(func, in_axes=0, out_axes=0)[source]#

Auto-vectorise func over a batch axis.

Parameters:
  • func (callable) – Function operating on a single example.

  • in_axes (int or tuple) – Which axes of each argument to vectorise over.

  • out_axes (int or tuple) – Output batch axis.

Returns:

callable – Batched version of func.

Return type:

Callable

Examples

>>> # Single-subject HKS: (k,), (N, k), (T,) → (N, T)
>>> batched = be.vmap(compute_hks, in_axes=(0, 0, None))  # t_values is shared, not mapped
>>> # Now: (S, k), (S, N, k), (T,) → (S, N, T)
>>> all_hks = batched(all_evals, all_evecs, t_values)
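The batching pattern can be sketched end to end with jax.vmap called directly. The compute_hks below is a stand-in written for this illustration, assuming the standard heat kernel signature HKS(x, t) = Σ_i exp(-λ_i t) φ_i(x)²; it is not necessarily the library's own function, and the data are random placeholders.

```python
import jax
import jax.numpy as jnp


def compute_hks(evals, evecs, t):
    """Heat kernel signature for one subject (illustrative stand-in).

    evals: (k,), evecs: (N, k), t: (T,) -> (N, T)
    """
    weights = jnp.exp(-evals[:, None] * t[None, :])  # (k, T)
    return (evecs ** 2) @ weights                    # (N, T)


# Random stand-in data: S subjects, N vertices, k modes, T time scales.
S, N, k, T = 4, 10, 5, 3
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
all_evals = jnp.abs(jax.random.normal(key_a, (S, k)))
all_evecs = jax.random.normal(key_b, (S, N, k))
t_values = jnp.logspace(-1.0, 1.0, T)

# Map over axis 0 of evals and evecs; None shares t_values across subjects.
batched_hks = jax.jit(jax.vmap(compute_hks, in_axes=(0, 0, None)))
all_hks = batched_hks(all_evals, all_evecs, t_values)

# Reference: explicit Python loop over subjects.
looped = jnp.stack(
    [compute_hks(all_evals[s], all_evecs[s], t_values) for s in range(S)]
)
```

The in_axes tuple matters here: an integer 0 would try to map over the leading axis of every argument, including the shared time grid.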
zeros(shape, dtype=<class 'numpy.float64'>)[source]#

Create a zero-filled JAX array.

Return type:

Any

name: str = 'jax'#
class spectralbrain.backends.gpu.NumPyroSampler(num_warmup=1000, num_samples=2000, num_chains=4, seed=42)[source]#

Bases: object

GPU-accelerated Bayesian MCMC using NumPyro + JAX.

NumPyro runs NUTS on XLA-compiled JAX graphs, achieving substantial speedups over PyMC on GPU for models with many parameters (e.g. hierarchical normative models with thousands of vertex-level effects).

Parameters:
  • num_warmup (int) – Warmup (tuning) samples.

  • num_samples (int) – Posterior draws.

  • num_chains (int) – Independent chains.

  • seed (int) – PRNG seed.

Examples

>>> import numpyro
>>> import numpyro.distributions as dist
>>> sampler = NumPyroSampler(num_warmup=500, num_samples=2000)
>>> # Define a NumPyro model function:
>>> def model(x, y=None):
...     alpha = numpyro.sample("alpha", dist.Normal(0, 1))
...     sigma = numpyro.sample("sigma", dist.HalfNormal(1))
...     mu = alpha * x
...     numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
>>> mcmc = sampler.sample(model, x=x_data, y=y_data)  # numpyro.infer.MCMC
sample(model, **model_kwargs)[source]#

Run NUTS on a NumPyro model function.

Parameters:
  • model (callable) – A NumPyro model function.

  • **model_kwargs – Data and hyperparameters passed to model.

Returns:

numpyro.infer.MCMC – MCMC object with .get_samples() and .print_summary().

Return type:

Any

to_arviz(mcmc)[source]#

Convert NumPyro MCMC to ArviZ InferenceData.

Parameters:

mcmc (numpyro.infer.MCMC)

Returns:

arviz.InferenceData

Return type:

Any

name: str = 'numpyro'#
class spectralbrain.backends.gpu.TorchBackend(device='cuda')[source]#

Bases: object

GPU compute backend using PyTorch.

Mirrors the NumpyBackend interface so it can be passed to BrainMesh.decompose() via backend=. Dense ops run as Torch tensors on the selected device; to_numpy() copies back to host.

PyTorch has no robust sparse generalised eigensolver, so eigsh() uses the same diagonal-mass standardisation as CupyBackend, Ã = D^{-1/2} L D^{-1/2}, solved with the dense torch.linalg.eigh on the device, and falls back to CPU sparse shift-invert for meshes above dense_max.

Parameters:

device (str) – "cuda" or "cpu". If "cuda" is requested but no GPU is available, the backend falls back to CPU.

Examples

>>> be = TorchBackend()
>>> evals, evecs = be.eigsh(L, M, k=100)  # host NumPy arrays
argsort(x, axis=-1)[source]#

Indices that sort the tensor.

Return type:

Any

array(data, dtype=<class 'numpy.float64'>)[source]#

Create a Torch tensor on the device.

Return type:

Any

clip(x, a_min, a_max)[source]#

Element-wise clamp.

Return type:

Any

concatenate(arrays, axis=0)[source]#

Concatenate along an existing axis.

Return type:

Any

eigsh(L, M=None, k=100, *, sigma=-0.01, which='LM', tol=0.0, maxiter=None, dense_max=20000)[source]#

Smallest-k generalised eigenpairs L v = λ M v on the device.

Uses the diagonal-mass standardisation Ã = D^{-1/2} L D^{-1/2} with a dense torch.linalg.eigh on the device, keeps the k smallest eigenpairs, and recovers the M-orthonormal eigenvectors v = D^{-1/2} ψ. Meshes with N > dense_max fall back to CPU sparse shift-invert to avoid densification OOM. sigma, which, tol, and maxiter are honoured only on that fallback path. Returns host (NumPy) arrays, matching NumpyBackend.eigsh().

Return type:

tuple[ndarray[tuple[Any, …], dtype[floating]], ndarray[tuple[Any, …], dtype[floating]]]

exp(x)[source]#

Element-wise exponential.

Parameters:

x (Any)

Return type:

Any

eye(n, dtype=<class 'numpy.float64'>)[source]#

Create an identity tensor.

Return type:

Any

jit(func, **kwargs)[source]#

Compile a function with torch.compile (Torch-native ops only).

Return type:

Callable

linspace(start, stop, num)[source]#

Evenly spaced values.

Return type:

Any

log(x)[source]#

Element-wise safe log.

Parameters:

x (Any)

Return type:

Any

logspace(start, stop, num)[source]#

Log-spaced values.

Return type:

Any

matmul(a, b)[source]#

Matrix multiply.

Return type:

Any

mean(x, axis=None)[source]#

Mean reduction.

Return type:

Any

norm(x, axis=None, ord=None)[source]#

Vector/matrix norm.

Return type:

Any

ones(shape, dtype=<class 'numpy.float64'>)[source]#

Create a ones-filled tensor.

Return type:

Any

sparse_matrix(data, row, col, shape, **kwargs)[source]#

Build a sparse COO tensor from triplets on the device.

Return type:

Any

sqrt(x)[source]#

Element-wise safe sqrt.

Parameters:

x (Any)

Return type:

Any

stack(arrays, axis=0)[source]#

Stack along a new axis.

Return type:

Any

sum(x, axis=None)[source]#

Sum reduction.

Return type:

Any

to_numpy(x)[source]#

Copy a device tensor to host.

Parameters:

x (Any)

Return type:

ndarray

vmap(func, **kwargs)[source]#

Vectorise a function with torch.func.vmap (Torch-native ops only).

Return type:

Callable

zeros(shape, dtype=<class 'numpy.float64'>)[source]#

Create a zero-filled tensor.

Return type:

Any

name: str = 'torch'#
class spectralbrain.backends.gpu.VRAMInfo(device_name, device_id, total_gb, used_gb, free_gb, percent_used)[source]#

Bases: object

Snapshot of GPU VRAM usage.

device_name#

GPU model name.

Type:

str

device_id#

CUDA device index.

Type:

int

total_gb#

Total VRAM.

Type:

float

used_gb#

Currently allocated VRAM.

Type:

float

free_gb#

Available VRAM.

Type:

float

percent_used#

Usage percentage.

Type:

float

device_id: int#
device_name: str#
free_gb: float#
percent_used: float#
total_gb: float#
used_gb: float#
spectralbrain.backends.gpu.get_gpu_backend(name='cupy', **kwargs)[source]#

Factory for GPU compute backends.

Parameters:
  • name ("cupy", "jax", or "torch")

  • **kwargs – Passed to the backend constructor.

Returns:

CupyBackend, JaxBackend, or TorchBackend

Return type:

CupyBackend | JaxBackend | TorchBackend
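A name-keyed factory that forwards keyword arguments to a constructor can be sketched as follows. The stand-in classes, the REGISTRY dict, and the ValueError on an unknown name are all assumptions made for illustration; the real factory may be organised differently and may raise a different error.

```python
class CupyLike:
    """Hypothetical stand-in for a backend that takes device_id."""
    name = "cupy"

    def __init__(self, device_id=0):
        self.device_id = device_id


class JaxLike:
    """Hypothetical stand-in for a backend that takes device."""
    name = "jax"

    def __init__(self, device="gpu"):
        self.device = device


# Map each backend's name attribute to its class.
REGISTRY = {cls.name: cls for cls in (CupyLike, JaxLike)}


def make_backend(name="cupy", **kwargs):
    """Look up a backend class by name and forward kwargs to its constructor."""
    try:
        cls = REGISTRY[name]
    except KeyError:
        # The error type is an assumption; the real factory may behave differently.
        raise ValueError(
            f"unknown backend {name!r}; expected one of {sorted(REGISTRY)}"
        ) from None
    return cls(**kwargs)


default_backend = make_backend()              # CupyLike with device_id=0
cpu_jax = make_backend("jax", device="cpu")   # kwargs reach the constructor
```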

spectralbrain.backends.gpu.get_gpu_bayesian_sampler(backend='numpyro', **kwargs)[source]#

Factory for GPU Bayesian samplers.

Parameters:
  • backend ("numpyro" or "blackjax") – Which JAX-based sampler to use. NumPyro takes a model function; BlackJAX takes a log-density function.

  • **kwargs – Passed to the sampler constructor (num_warmup, num_samples, num_chains, seed).

Returns:

NumPyroSampler or BlackjaxSampler

Return type:

NumPyroSampler | BlackjaxSampler

spectralbrain.backends.gpu.vram_clear()[source]#

Clear CUDA memory caches across all known GPU frameworks.

Calls cache-clearing functions for CuPy, JAX, and PyTorch (if installed). Safe to call even when no GPU framework is loaded.

Return type:

None

spectralbrain.backends.gpu.vram_defrag(device_id=0)[source]#

Attempt CUDA memory defragmentation.

Frees cached blocks and triggers a synchronisation barrier that allows the CUDA driver to consolidate fragmented allocations.

Parameters:

device_id (int) – CUDA device index.

Return type:

None

Notes

True defragmentation is limited by CUDA’s memory model — once an allocation is placed, it cannot be moved. This function does the best available: free caches, synchronise, and let the driver reclaim contiguous regions.

spectralbrain.backends.gpu.vram_gc(device_id=0)[source]#

Full GPU garbage collection: Python GC + VRAM clear + defrag.

Parameters:

device_id (int) – CUDA device index.

Return type:

None

spectralbrain.backends.gpu.vram_guard(min_free_gb=1.0, device_id=0, error_on_low=False, auto_clear=True)[source]#

Context manager for VRAM-safe GPU operations.

Checks free VRAM before the block runs. If it is below min_free_gb, clears caches when auto_clear is set and raises MemoryError when error_on_low is set. Reports the change in VRAM usage after the block completes.

Parameters:
  • min_free_gb (float) – Minimum required free VRAM.

  • device_id (int) – CUDA device index.

  • error_on_low (bool) – Raise MemoryError if VRAM is insufficient.

  • auto_clear (bool) – Attempt to clear caches if VRAM is low.

Return type:

Generator[None, None, None]

Examples

>>> with vram_guard(min_free_gb=4.0):
...     result = gpu_heavy_computation()
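The check, clear, and report sequence can be sketched with contextlib. This is a framework-free illustration, not the function's source: memory_guard, the injected get_free_gb probe, clear_fn, and log are hypothetical names, and the order in which clearing and the error check happen is an assumption.

```python
import contextlib
import gc


@contextlib.contextmanager
def memory_guard(get_free_gb, min_free_gb=1.0, error_on_low=False,
                 auto_clear=True, clear_fn=gc.collect, log=print):
    """Run a block only after checking that enough memory is free.

    get_free_gb is a caller-supplied probe returning free memory in GB;
    clear_fn is whatever cache-clearing routine is available.
    """
    free_before = get_free_gb()
    if free_before < min_free_gb and auto_clear:
        clear_fn()                      # try to reclaim memory first
        free_before = get_free_gb()     # re-measure after clearing
    if free_before < min_free_gb and error_on_low:
        raise MemoryError(
            f"{free_before:.2f} GB free, {min_free_gb:.2f} GB required"
        )
    try:
        yield
    finally:
        # Report how much free memory changed across the block.
        delta = get_free_gb() - free_before
        log(f"free memory delta: {delta:+.2f} GB")


# Demonstration with a fake probe: 0.5 GB free until the cache is cleared.
state = {"free": 0.5}
messages = []


def fake_probe():
    return state["free"]


def fake_clear():
    state["free"] = 2.0


with memory_guard(fake_probe, min_free_gb=1.0, clear_fn=fake_clear,
                  log=messages.append):
    state["free"] -= 0.75   # the guarded block "allocates" 0.75 GB
```

Injecting the probe and the clearing routine keeps the sketch testable without a GPU; a real guard would query the device instead.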
spectralbrain.backends.gpu.vram_monitor(device_id=0, label='')[source]#

Log current VRAM usage (one-shot, for debugging).

Parameters:
  • device_id (int) – CUDA device index.

  • label (str) – Optional context label for the log message.

Return type:

None

Examples

>>> vram_monitor(label="after eigsolve")
# GPU 0 (RTX 3090): 3.42 / 24.00 GB (14%) [after eigsolve]
spectralbrain.backends.gpu.vram_status(device_id=0)[source]#

Query current VRAM usage.

Tries CuPy first, then nvidia-smi as fallback.

Parameters:

device_id (int) – CUDA device index.

Returns:

VRAMInfo

Return type:

VRAMInfo
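The nvidia-smi fallback can be sketched as a parser for one line of CSV query output. The GpuMemory dataclass (mirroring the VRAMInfo fields), the parse_smi_line helper, and the sample line are hypothetical. The query flags are standard nvidia-smi options that report memory in MiB, converted here with 1 GB = 1024 MiB, and GPU names are assumed to contain no commas.

```python
from dataclasses import dataclass


@dataclass
class GpuMemory:
    """Hypothetical stand-in with the same fields as VRAMInfo."""
    device_name: str
    device_id: int
    total_gb: float
    used_gb: float
    free_gb: float
    percent_used: float


# Query that produces one CSV line per GPU, with memory values in MiB.
NVIDIA_SMI_CMD = [
    "nvidia-smi",
    "--query-gpu=name,memory.total,memory.used,memory.free",
    "--format=csv,noheader,nounits",
]


def parse_smi_line(line, device_id=0):
    """Parse one 'name, total, used, free' line (MiB) into GB fields."""
    name, total, used, free = [part.strip() for part in line.split(",")]
    total_gb, used_gb, free_gb = (float(v) / 1024.0 for v in (total, used, free))
    return GpuMemory(
        device_name=name,
        device_id=device_id,
        total_gb=total_gb,
        used_gb=used_gb,
        free_gb=free_gb,
        percent_used=100.0 * used_gb / total_gb,
    )


# Made-up sample line; on a real machine it would come from
# subprocess.run(NVIDIA_SMI_CMD, capture_output=True, text=True).stdout
info = parse_smi_line("Example GPU, 24576, 6144, 18432")
```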