spectralbrain.backends.gpu
GPU compute backends — CuPy, JAX, NumPyro, and VRAM management.
This module provides:
CupyBackend: drop-in GPU replacement for NumpyBackend using CuPy (cupy-cuda13x).
JaxBackend: GPU backend with jit and vmap for batch-subject spectral descriptor computation.
NumPyroSampler: GPU-accelerated Bayesian MCMC via JAX.
VRAM management: monitoring, cache clearing, defragmentation, garbage collection, and a memory-guarded context manager.
All dependencies are lazy-imported. If CuPy or JAX is missing, the
module still imports successfully — only instantiation of the backends
raises ImportError.
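The lazy-import behaviour described above can be illustrated with a minimal sketch. `LazyBackend` is a hypothetical class, not part of SpectralBrain; it only shows the general pattern of deferring a heavy import to instantiation so that importing the enclosing module never fails.

```python
import importlib


class LazyBackend:
    """Hypothetical backend whose dependency is imported on instantiation."""

    def __init__(self, module_name: str):
        try:
            # The import happens here, not when the enclosing module is imported.
            self.xp = importlib.import_module(module_name)
        except ImportError as exc:
            raise ImportError(
                f"{module_name!r} is required for this backend but is not installed"
            ) from exc


# A stdlib module stands in for an installed GPU library.
ok = LazyBackend("math")
print(ok.xp.sqrt(4.0))

# A missing dependency only fails when the backend is constructed.
try:
    LazyBackend("definitely_not_installed_gpu_lib")
except ImportError as exc:
    print("ImportError:", exc)
```

Defining the class never touches the dependency, so callers can import the module and probe for optional backends without wrapping the import itself in a try block.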
Functions
| get_gpu_backend | Factory for GPU compute backends. |
| get_gpu_bayesian_sampler | Factory for GPU Bayesian samplers. |
| vram_clear | Clear CUDA memory caches across all known GPU frameworks. |
| vram_defrag | Attempt CUDA memory defragmentation. |
| vram_gc | Full GPU garbage collection: Python GC + VRAM clear + defrag. |
| vram_guard | Context manager for VRAM-safe GPU operations. |
| | Log current VRAM usage (one-shot, for debugging). |
| | Query current VRAM usage. |
Classes
| BlackjaxSampler | GPU-accelerated Bayesian NUTS via BlackJAX. |
| CupyBackend | GPU compute backend using CuPy. |
| JaxBackend | GPU backend using JAX with jit and vmap. |
| NumPyroSampler | GPU-accelerated Bayesian MCMC using NumPyro + JAX. |
| TorchBackend | GPU compute backend using PyTorch. |
| VRAMInfo | Snapshot of GPU VRAM usage. |
- class spectralbrain.backends.gpu.BlackjaxSampler(num_warmup=1000, num_samples=2000, num_chains=4, seed=42)
Bases: object
GPU-accelerated Bayesian NUTS via BlackJAX.
BlackJAX is a low-level sampler that operates on a log-density function rather than a model object, which makes it composable and fast under JAX jit/vmap on the GPU. This wrapper runs the standard window-adaptation → NUTS pipeline and (optionally) vectorises independent chains with jax.vmap().
Examples
>>> import jax.numpy as jnp
>>> def logdensity(theta):
...     # standard-normal target
...     return -0.5 * jnp.sum(theta ** 2)
>>> sampler = BlackjaxSampler(num_warmup=500, num_samples=1000, num_chains=1)
>>> samples = sampler.sample(logdensity, initial_position=jnp.zeros(3))
>>> samples.shape  # (num_samples, 3)
- sample(logdensity_fn, initial_position, *, rng_key=None)
Run NUTS on a log-density function.
- Parameters:
logdensity_fn (callable) – Maps a parameter pytree to a scalar log-density (unnormalised log-posterior). Must be JAX-traceable.
initial_position (pytree) – Starting position for a single chain. For num_chains > 1 it is broadcast across chains.
rng_key (jax.Array, optional) – PRNG key. Defaults to jax.random.PRNGKey(self.seed).
- Returns:
pytree – Posterior draws. For a single chain each leaf has shape (num_samples, *param_shape); for multiple chains, (num_chains, num_samples, *param_shape).
- class spectralbrain.backends.gpu.CupyBackend(device_id=0)
Bases: object
GPU compute backend using CuPy.
Mirrors the NumpyBackend interface. Arrays live on the GPU; to_numpy() copies back to host.
- Parameters:
device_id (int) – CUDA device index.
Examples
>>> be = CupyBackend(device_id=0)
>>> evals, evecs = be.eigsh(L, M, k=100)
>>> type(evals)  # numpy.ndarray: eigsh returns host arrays
- eigsh(L, M=None, k=100, *, sigma=-0.01, which='LM', tol=0.0, maxiter=None, dense_max=20000)
Smallest-k generalised eigenpairs L v = λ M v on the GPU.
CuPy’s sparse eigsh supports neither the generalised problem (no M) nor shift-invert (no sigma), and recovering the smallest Laplacian eigenvalues by plain Lanczos is unreliable. Because the FEM mass matrix is diagonal (lumped barycentric), the generalised problem standardises exactly to a symmetric one:
Ã = D^{-1/2} L D^{-1/2}, D = diag(M), ψ = D^{1/2} v
We solve the dense symmetric eigenproblem Ã ψ = λ ψ on the GPU (cupy.linalg.eigh: robust, no ARPACK convergence issues), keep the k smallest, and recover the M-orthonormal eigenvectors v = D^{-1/2} ψ. Validated against SciPy shift-invert and the analytic sphere spectrum.
Meshes with N > dense_max fall back to CPU sparse shift-invert to avoid densification OOM. sigma/which/tol/maxiter are honoured only on the fallback path; the signature mirrors NumpyBackend.eigsh(). Returns host arrays.
- sparse_matrix(data, row, col, shape, **kwargs)
Build a sparse matrix from COO triplets on GPU.
- class spectralbrain.backends.gpu.JaxBackend(device='gpu')
Bases: object
GPU backend using JAX with jit and vmap.
The key advantages of JAX over CuPy for SpectralBrain are jax.vmap(), which vectorises descriptor computation across an entire cohort without explicit loops, and jax.jit(), which compiles hot paths for reuse.
- Parameters:
device (str) – "gpu" or "cpu".
Examples
>>> be = JaxBackend()
>>> # Batch HKS for 228 subjects (t is shared, so it is not mapped):
>>> batched_hks = be.vmap(compute_hks, in_axes=(0, 0, None))(all_evals, all_evecs, t)
- eigsh(L, M=None, k=100, **kwargs)
Sparse eigensolver via JAX’s LOBPCG.
For the generalised problem L v = λ M v, falls back to SciPy ARPACK on host and transfers results, because JAX’s sparse eigensolver does not yet support generalised problems natively. The eigenvalues and eigenvectors are returned as NumPy arrays.
- vmap(func, in_axes=0, out_axes=0)
Auto-vectorise func over a batch axis.
- Returns:
callable – Batched version of func.
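Batching over a leading subject axis turns per-subject inputs of shape (k,), (N, k), (T,) into (S, k), (S, N, k), (T,) and the output from (N, T) into (S, N, T). A minimal NumPy sketch of that shape contract follows. It assumes `compute_hks` is the standard heat kernel signature, HKS(x, t) = Σ_i exp(-λ_i t) φ_i(x)²; both functions here are hypothetical stand-ins, and the batched one uses `einsum` in place of `jax.vmap` so the check runs without a GPU.

```python
import numpy as np


def compute_hks(evals, evecs, t):
    """Hypothetical single-subject HKS: (k,), (N, k), (T,) -> (N, T)."""
    heat = np.exp(-np.outer(evals, t))  # (k, T)
    return (evecs ** 2) @ heat          # (N, T)


def compute_hks_batched(all_evals, all_evecs, t):
    """Batched over subjects: (S, k), (S, N, k), (T,) -> (S, N, T)."""
    heat = np.exp(-all_evals[:, :, None] * t[None, None, :])  # (S, k, T)
    return np.einsum("snk,skt->snt", all_evecs ** 2, heat)


rng = np.random.default_rng(0)
S, N, k, T = 4, 30, 6, 5
all_evals = np.sort(rng.uniform(0.0, 10.0, size=(S, k)), axis=1)
all_evecs = rng.standard_normal((S, N, k))
t_values = np.logspace(-2, 0, T)

batched = compute_hks_batched(all_evals, all_evecs, t_values)
# Reference: an explicit per-subject loop, which batching is meant to replace.
looped = np.stack(
    [compute_hks(all_evals[s], all_evecs[s], t_values) for s in range(S)]
)
print(batched.shape, np.allclose(batched, looped))
```

With JAX, the single-subject function written against `jax.numpy` can be batched the same way with `jax.vmap(compute_hks, in_axes=(0, 0, None))`, where `None` marks the time grid as shared across subjects.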
Examples
>>> # Single-subject HKS: (k,), (N, k), (T,) → (N, T)
>>> batched = be.vmap(compute_hks, in_axes=(0, 0, None))
>>> # Now: (S, k), (S, N, k), (T,) → (S, N, T)
>>> all_hks = batched(all_evals, all_evecs, t_values)
- class spectralbrain.backends.gpu.NumPyroSampler(num_warmup=1000, num_samples=2000, num_chains=4, seed=42)
Bases: object
GPU-accelerated Bayesian MCMC using NumPyro + JAX.
NumPyro runs NUTS on XLA-compiled JAX graphs, achieving substantial speedups over PyMC on GPU for models with many parameters (e.g. hierarchical normative models with thousands of vertex-level effects).
Examples
>>> import numpyro
>>> import numpyro.distributions as dist
>>> sampler = NumPyroSampler(num_warmup=500, num_samples=2000)
>>> # Define a NumPyro model function:
>>> def model(x, y=None):
...     alpha = numpyro.sample("alpha", dist.Normal(0, 1))
...     sigma = numpyro.sample("sigma", dist.HalfNormal(1))
...     mu = alpha * x
...     numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
>>> trace = sampler.sample(model, x=x_data, y=y_data)
- sample(model, **model_kwargs)
Run NUTS on a NumPyro model function.
- Parameters:
model (callable) – A NumPyro model function.
**model_kwargs – Data and hyperparameters passed to model.
- Returns:
numpyro.infer.MCMC – MCMC object with .get_samples() and .print_summary().
- class spectralbrain.backends.gpu.TorchBackend(device='cuda')
Bases: object
GPU compute backend using PyTorch.
Mirrors the NumpyBackend interface so it can be passed to BrainMesh.decompose() via backend=. Dense ops run as Torch tensors on the selected device; to_numpy() copies back to host.
PyTorch has no robust sparse generalised eigensolver, so eigsh() uses the same diagonal-mass standardisation as CupyBackend (Ã = D^{-1/2} L D^{-1/2}, solved with the dense torch.linalg.eigh on the device) and falls back to CPU sparse shift-invert for meshes above dense_max.
- Parameters:
device (str) – "cuda" or "cpu". If "cuda" is requested but no GPU is available, the backend falls back to CPU.
Examples
>>> be = TorchBackend()
>>> evals, evecs = be.eigsh(L, M, k=100)  # host NumPy arrays
- eigsh(L, M=None, k=100, *, sigma=-0.01, which='LM', tol=0.0, maxiter=None, dense_max=20000)
Smallest-k generalised eigenpairs L v = λ M v on the device.
Uses the diagonal-mass standardisation Ã = D^{-1/2} L D^{-1/2} with a dense torch.linalg.eigh on the device, keeps the k smallest, and recovers the M-orthonormal eigenvectors v = D^{-1/2} ψ. Meshes with N > dense_max fall back to CPU sparse shift-invert to avoid densification OOM. sigma/which/tol/maxiter are honoured only on that fallback path. Returns host (NumPy) arrays, matching NumpyBackend.eigsh().
- sparse_matrix(data, row, col, shape, **kwargs)
Build a sparse COO tensor from triplets on the device.
- class spectralbrain.backends.gpu.VRAMInfo(device_name, device_id, total_gb, used_gb, free_gb, percent_used)
Bases: object
Snapshot of GPU VRAM usage.
- spectralbrain.backends.gpu.get_gpu_backend(name='cupy', **kwargs)
Factory for GPU compute backends.
- Parameters:
name ("cupy", "jax", or "torch")
**kwargs – Passed to the backend constructor.
- Returns:
CupyBackend, JaxBackend, or TorchBackend
- spectralbrain.backends.gpu.get_gpu_bayesian_sampler(backend='numpyro', **kwargs)
Factory for GPU Bayesian samplers.
- Parameters:
backend ("numpyro" or "blackjax") – Which JAX-based sampler to use. NumPyro takes a model function; BlackJAX takes a log-density function.
**kwargs – Passed to the sampler constructor (num_warmup, num_samples, num_chains, seed).
- Returns:
NumPyroSampler or BlackjaxSampler
- spectralbrain.backends.gpu.vram_clear()
Clear CUDA memory caches across all known GPU frameworks.
Calls cache-clearing functions for CuPy, JAX, and PyTorch (if installed). Safe to call even when no GPU framework is loaded.
- Return type:
None
- spectralbrain.backends.gpu.vram_defrag(device_id=0)
Attempt CUDA memory defragmentation.
Frees cached blocks and triggers a synchronisation barrier that allows the CUDA driver to consolidate fragmented allocations.
- Parameters:
device_id (int) – CUDA device index.
- Return type:
None
Notes
True defragmentation is limited by CUDA’s memory model — once an allocation is placed, it cannot be moved. This function does the best available: free caches, synchronise, and let the driver reclaim contiguous regions.
- spectralbrain.backends.gpu.vram_gc(device_id=0)
Full GPU garbage collection: Python GC + VRAM clear + defrag.
- Parameters:
device_id (int) – CUDA device index.
- Return type:
None
- spectralbrain.backends.gpu.vram_guard(min_free_gb=1.0, device_id=0, error_on_low=False, auto_clear=True)
Context manager for VRAM-safe GPU operations.
Checks free VRAM on device_id before the block runs. If it is below min_free_gb, the guard optionally clears caches (auto_clear) or raises an error (error_on_low). After the block completes, it reports the change in VRAM usage.
- Return type:
Generator[None, None, None]
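The check-then-report flow of a memory guard can be sketched with the standard library alone. Everything below is illustrative: `memory_guard`, its `query_free_gb` probe, and its `clear` hook are hypothetical names, and the choice to raise only if memory is still low after clearing is an assumption, not documented behaviour of vram_guard.

```python
import contextlib
import gc


@contextlib.contextmanager
def memory_guard(query_free_gb, min_free_gb=1.0, error_on_low=False,
                 auto_clear=True, clear=gc.collect):
    """Hypothetical guard: check free memory, optionally clear or raise, report delta."""
    free_before = query_free_gb()
    if free_before < min_free_gb:
        if auto_clear:
            clear()
            free_before = query_free_gb()
        # Assumption: raise only if memory is still low after the optional clear.
        if free_before < min_free_gb and error_on_low:
            raise MemoryError(
                f"only {free_before:.2f} GB free, need {min_free_gb:.2f} GB"
            )
    try:
        yield
    finally:
        # Positive delta means the guarded block consumed memory.
        delta = free_before - query_free_gb()
        print(f"memory delta: {delta:+.2f} GB")


# A fake probe stands in for a real VRAM query.
state = {"free": 8.0}

with memory_guard(lambda: state["free"], min_free_gb=4.0):
    state["free"] -= 1.5  # pretend the block allocated 1.5 GB
```

Injecting the probe as a callable keeps the guard testable without a GPU; a real implementation would query the device instead.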
Examples
>>> with vram_guard(min_free_gb=4.0):
...     result = gpu_heavy_computation()