Source code for spectralbrain.backends.cpu

"""CPU compute backend — NumPy/SciPy, Bayesian samplers, and parallelisation.

This module provides:

1. **NumpyBackend** — the default compute engine using SciPy sparse
   eigensolvers (ARPACK), sparse matrix operations, and NumPy array
   algebra.  Every other backend mirrors this interface.
2. **PyMCSampler / NutpieSampler** — Bayesian MCMC backends for the
   ``statistics/bayesian.py`` module.
3. **Joblib utilities** — composable parallelisation helpers with
   Rich progress integration.
4. **RAM management** — memory monitoring, garbage collection, and
   estimation helpers for multi-subject pipelines.

All optional dependencies (PyMC, nutpie, joblib) are lazy-imported.
Only NumPy and SciPy are hard requirements.
"""

from __future__ import annotations

import gc
from collections.abc import Callable, Generator, Iterator, Sequence
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import (
    Any,
    Literal,
    TypeVar,
)

import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla

from spectralbrain.runtime import (
    Eigenvalues,
    Eigenvectors,
    MassMatrix,
    SparseMatrix,
    get_logger,
    progress_parallel,
)

logger = get_logger(__name__)

T = TypeVar("T")
R = TypeVar("R")


# ======================================================================
# §1  NUMPY / SCIPY COMPUTE BACKEND
# ======================================================================


[docs] class NumpyBackend: """CPU compute backend using NumPy + SciPy. Provides the canonical interface that :class:`CupyBackend` and :class:`JaxBackend` mirror. All ``core/`` and ``spectral/`` modules call backend methods rather than importing NumPy or SciPy directly, enabling transparent GPU acceleration. Examples -------- >>> from spectralbrain.backends.cpu import NumpyBackend >>> be = NumpyBackend() >>> evals, evecs = be.eigsh(L, M, k=100) >>> hks = be.exp(-evals[None, :] * t[:, None]) # broadcasting """ name: str = "numpy" # ── Sparse eigensolvers ───────────────────────────────────────────
[docs] @staticmethod def eigsh( L: SparseMatrix, M: MassMatrix | None = None, k: int = 100, *, sigma: float = -0.01, which: str = "LM", tol: float = 0.0, maxiter: int | None = None, ) -> tuple[Eigenvalues, Eigenvectors]: """Solve the generalised sparse eigenproblem L v = λ M v. Uses SciPy's ARPACK wrapper in shift-invert mode (default σ = −0.01) which is optimal for computing the *smallest* eigenvalues of the Laplacian. Parameters ---------- L : sparse matrix, shape (N, N) Stiffness (Laplacian) matrix — symmetric positive semi-definite. M : sparse matrix, shape (N, N), optional Mass matrix. If ``None``, the standard eigenproblem L v = λ v is solved. k : int Number of eigenpairs to compute. sigma : float Shift for shift-invert mode. A small negative value avoids the singularity at λ = 0. which : str Which eigenvalues to target (``"LM"`` = largest magnitude *of the shifted operator*, yielding the smallest λ). tol : float Convergence tolerance (0 = machine precision). maxiter : int, optional Maximum ARPACK iterations. Returns ------- eigenvalues : ndarray, shape (k,) Sorted ascending, float64. eigenvectors : ndarray, shape (N, k) Corresponding eigenvectors, M-orthonormal. Raises ------ scipy.sparse.linalg.ArpackNoConvergence If ARPACK fails to converge within *maxiter* iterations. """ L = sp.csc_matrix(L, dtype=np.float64) if M is not None: M = sp.csc_matrix(M, dtype=np.float64) eigenvalues, eigenvectors = spla.eigsh( L, k=k, M=M, sigma=sigma, which=which, tol=tol, maxiter=maxiter, ) # Sort ascending (ARPACK returns in arbitrary order after # shift-invert). order = np.argsort(eigenvalues) eigenvalues = eigenvalues[order] eigenvectors = eigenvectors[:, order] # Clamp tiny negative eigenvalues from numerical noise. eigenvalues = np.clip(eigenvalues, 0.0, None) return eigenvalues, eigenvectors
# ── Sparse matrix construction ────────────────────────────────────
[docs] @staticmethod def sparse_matrix( data: np.ndarray, row: np.ndarray, col: np.ndarray, shape: tuple[int, int], *, format: str = "csc", ) -> SparseMatrix: """Build a sparse matrix from COO triplets. Parameters ---------- data : ndarray Non-zero values. row, col : ndarray Row and column indices. shape : (int, int) Matrix dimensions. format : str Output format (``"csc"``, ``"csr"``, ``"coo"``). Returns ------- SparseMatrix """ coo = sp.coo_matrix( ( np.asarray(data, dtype=np.float64), (np.asarray(row, dtype=np.int64), np.asarray(col, dtype=np.int64)), ), shape=shape, ) if format == "csc": return coo.tocsc() elif format == "csr": return coo.tocsr() return coo
# ── Dense array operations ──────────────────────────────────────── # These thin wrappers exist so that CupyBackend / JaxBackend can # override them transparently.
[docs] @staticmethod def array(data: Any, dtype: np.dtype = np.float64) -> np.ndarray: """Create a dense array.""" return np.asarray(data, dtype=dtype)
[docs] @staticmethod def zeros(shape: tuple[int, ...], dtype: np.dtype = np.float64) -> np.ndarray: """Create a zero-filled array (mirrors numpy.zeros).""" return np.zeros(shape, dtype=dtype)
[docs] @staticmethod def ones(shape: tuple[int, ...], dtype: np.dtype = np.float64) -> np.ndarray: """Create a ones-filled array (mirrors numpy.ones).""" return np.ones(shape, dtype=dtype)
[docs] @staticmethod def eye(n: int, dtype: np.dtype = np.float64) -> np.ndarray: """Create an identity matrix (mirrors numpy.eye).""" return np.eye(n, dtype=dtype)
[docs] @staticmethod def matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray: """Matrix multiply (sparse- and dense-aware).""" if sp.issparse(a) or sp.issparse(b): return a @ b return np.matmul(a, b)
[docs] @staticmethod def exp(x: np.ndarray) -> np.ndarray: """Element-wise exponential (mirrors numpy.exp).""" return np.exp(x)
[docs] @staticmethod def log(x: np.ndarray) -> np.ndarray: """Element-wise safe log with clamp at 1e-300.""" return np.log(np.clip(x, 1e-300, None))
[docs] @staticmethod def sqrt(x: np.ndarray) -> np.ndarray: """Element-wise safe sqrt with clamp at 0.""" return np.sqrt(np.clip(x, 0.0, None))
[docs] @staticmethod def sum(x: np.ndarray, axis: int | None = None) -> np.ndarray: """Sum reduction (mirrors numpy.sum).""" return np.sum(x, axis=axis)
[docs] @staticmethod def mean(x: np.ndarray, axis: int | None = None) -> np.ndarray: """Mean reduction (mirrors numpy.mean).""" return np.mean(x, axis=axis)
[docs] @staticmethod def clip(x: np.ndarray, a_min: float | None, a_max: float | None) -> np.ndarray: """Element-wise clip (mirrors numpy.clip).""" return np.clip(x, a_min, a_max)
[docs] @staticmethod def to_numpy(x: Any) -> np.ndarray: """Convert any array-like to a NumPy ndarray.""" if isinstance(x, np.ndarray): return x if sp.issparse(x): return x.toarray() return np.asarray(x)
[docs] @staticmethod def norm(x: np.ndarray, axis: int | None = None, ord: int | None = None) -> np.ndarray: """Vector/matrix norm (mirrors numpy.linalg.norm).""" return np.linalg.norm(x, axis=axis, ord=ord)
[docs] @staticmethod def argsort(x: np.ndarray, axis: int = -1) -> np.ndarray: """Indirect sort indices (mirrors numpy.argsort).""" return np.argsort(x, axis=axis)
[docs] @staticmethod def concatenate(arrays: Sequence[np.ndarray], axis: int = 0) -> np.ndarray: """Concatenate arrays along an axis.""" return np.concatenate(arrays, axis=axis)
[docs] @staticmethod def stack(arrays: Sequence[np.ndarray], axis: int = 0) -> np.ndarray: """Stack arrays along a new axis.""" return np.stack(arrays, axis=axis)
[docs] @staticmethod def linspace(start: float, stop: float, num: int) -> np.ndarray: """Linearly spaced values (mirrors numpy.linspace).""" return np.linspace(start, stop, num, dtype=np.float64)
[docs] @staticmethod def logspace(start: float, stop: float, num: int) -> np.ndarray: """Log-spaced values (mirrors numpy.logspace).""" return np.logspace(start, stop, num, dtype=np.float64)
# ====================================================================== # §2 BAYESIAN CPU SAMPLERS # ====================================================================== def _require_pymc(): """Lazy-import PyMC, raising ImportError if unavailable.""" try: import pymc as pm return pm except ImportError as exc: raise ImportError("PyMC is required for Bayesian analysis.\n pip install pymc") from exc def _require_nutpie(): """Lazy-import nutpie, raising ImportError if unavailable.""" try: import nutpie return nutpie except ImportError as exc: raise ImportError( "nutpie is required for the nutpie sampler backend.\n pip install nutpie" ) from exc def _require_arviz(): """Lazy-import ArviZ, raising ImportError if unavailable.""" try: import arviz as az return az except ImportError as exc: raise ImportError( "ArviZ is required for Bayesian diagnostics.\n pip install arviz" ) from exc
[docs] @dataclass class SamplerConfig: """Configuration for Bayesian MCMC samplers. Parameters ---------- draws : int Number of posterior draws per chain. tune : int Number of tuning (burn-in) samples. chains : int Number of independent chains. cores : int CPU cores for parallel chains. target_accept : float Target acceptance probability for NUTS. random_seed : int or None RNG seed for reproducibility. """ draws: int = 2000 tune: int = 1000 chains: int = 4 cores: int = 4 target_accept: float = 0.95 random_seed: int | None = 42
[docs] class PyMCSampler: """Bayesian sampler using PyMC's native NUTS implementation. This is the default CPU sampler. It wraps ``pymc.sample()`` with SpectralBrain-compatible configuration and logging. Parameters ---------- config : SamplerConfig, optional Sampling configuration. Examples -------- >>> sampler = PyMCSampler(SamplerConfig(draws=1000, chains=2)) >>> with pm.Model() as model: ... mu = pm.Normal("mu", 0, 1) ... obs = pm.Normal("obs", mu, 1, observed=data) >>> trace = sampler.sample(model) """ name: str = "nuts" def __init__(self, config: SamplerConfig | None = None) -> None: """Initialise with optional SamplerConfig.""" self.config = config or SamplerConfig()
[docs] def sample( self, model: Any, # pm.Model **kwargs: Any, ) -> Any: # az.InferenceData """Run NUTS sampling on a PyMC model. Parameters ---------- model : pymc.Model A fully specified PyMC model. **kwargs Overrides passed to ``pymc.sample()``. Returns ------- arviz.InferenceData Posterior samples with diagnostics. """ pm = _require_pymc() cfg = self.config sample_kwargs = dict( draws=cfg.draws, tune=cfg.tune, chains=cfg.chains, cores=cfg.cores, target_accept=cfg.target_accept, random_seed=cfg.random_seed, return_inferencedata=True, progressbar=True, ) sample_kwargs.update(kwargs) logger.info( "PyMC NUTS: %d draws × %d chains (%d tune)", cfg.draws, cfg.chains, cfg.tune, ) with model: trace = pm.sample(**sample_kwargs) return trace
[docs] class NutpieSampler: """Bayesian sampler using nutpie (Rust-based NUTS). nutpie is a high-performance drop-in replacement for PyMC's default sampler. It compiles the PyMC model to Rust and runs NUTS 2–10× faster on CPU. Parameters ---------- config : SamplerConfig, optional Sampling configuration. Examples -------- >>> sampler = NutpieSampler(SamplerConfig(draws=2000)) >>> trace = sampler.sample(model) """ name: str = "nutpie" def __init__(self, config: SamplerConfig | None = None) -> None: """Initialise with optional SamplerConfig.""" self.config = config or SamplerConfig()
[docs] def sample( self, model: Any, # pm.Model **kwargs: Any, ) -> Any: # az.InferenceData """Run nutpie NUTS on a PyMC model. Parameters ---------- model : pymc.Model A fully specified PyMC model. **kwargs Overrides passed to ``nutpie.sample()``. Returns ------- arviz.InferenceData """ nutpie = _require_nutpie() cfg = self.config logger.info( "nutpie NUTS: %d draws × %d chains (%d tune)", cfg.draws, cfg.chains, cfg.tune, ) compiled = nutpie.compile_pymc_model(model) trace = nutpie.sample( compiled, draws=cfg.draws, tune=cfg.tune, chains=cfg.chains, seed=cfg.random_seed, progress_bar=True, **kwargs, ) return trace
[docs] def get_bayesian_sampler( backend: Literal["nuts", "nutpie"] = "nuts", config: SamplerConfig | None = None, ) -> PyMCSampler | NutpieSampler: """Factory for CPU Bayesian samplers. Parameters ---------- backend : ``"nuts"`` or ``"nutpie"`` Which sampler to use. config : SamplerConfig, optional Sampling parameters. Returns ------- PyMCSampler or NutpieSampler """ if backend == "nuts": return PyMCSampler(config) elif backend == "nutpie": return NutpieSampler(config) else: raise ValueError(f"Unknown CPU Bayesian backend: {backend!r}")
# ====================================================================== # §3 JOBLIB PARALLELISATION UTILITIES # ====================================================================== def _require_joblib(): """Lazy-import joblib, raising ImportError if unavailable.""" try: import joblib return joblib except ImportError as exc: raise ImportError("joblib is required for parallelisation.\n pip install joblib") from exc
[docs] def parallel_map( func: Callable[..., R], items: Sequence[T], *, n_jobs: int = -1, backend: str = "loky", progress: bool = True, description: str = "Processing", **func_kwargs: Any, ) -> list[R]: """Apply *func* to each item in parallel with optional progress. A thin wrapper around ``joblib.Parallel`` that integrates with SpectralBrain's Rich progress bars. Parameters ---------- func : callable Function to apply. Must accept each item as its first positional argument. items : sequence Items to process. n_jobs : int Number of parallel workers (``-1`` = all cores). backend : str Joblib backend (``"loky"``, ``"threading"``, ``"multiprocessing"``). progress : bool Show a Rich progress bar. description : str Progress bar label. **func_kwargs Extra keyword arguments passed to *func*. Returns ------- list Results in the same order as *items*. Examples -------- >>> def compute(subj_id, k=100): ... mesh = load(subj_id) ... return decompose(mesh, k=k) >>> results = parallel_map(compute, subject_ids, n_jobs=8, k=50) """ joblib = _require_joblib() total = len(items) if not progress: return joblib.Parallel(n_jobs=n_jobs, backend=backend)( joblib.delayed(func)(item, **func_kwargs) for item in items ) # Stream results back in submission order and advance the progress bar # in THIS (parent) process. The progress object holds a thread lock and # must never be captured in the worker closure — doing so breaks pickling # under the process-based ``loky`` backend. results: list[R | None] = [None] * total with progress_parallel(description, total=total) as tick: stream = joblib.Parallel(n_jobs=n_jobs, backend=backend, return_as="generator")( joblib.delayed(func)(item, **func_kwargs) for item in items ) for idx, result in enumerate(stream): results[idx] = result tick(1) return results # type: ignore[return-value]
[docs] def parallel_batch( func: Callable[[np.ndarray], np.ndarray], data: np.ndarray, *, batch_size: int = 1000, n_jobs: int = -1, axis: int = 0, progress: bool = True, description: str = "Batch processing", ) -> np.ndarray: """Apply *func* to batches of an array in parallel. Splits *data* along *axis* into chunks of *batch_size*, applies *func* to each chunk in parallel, and concatenates the results. Useful for operations that are O(N²) per vertex but can be batched (e.g. geodesic distance computation). Parameters ---------- func : callable Function that accepts an ndarray batch and returns an ndarray. data : ndarray Full array to process. batch_size : int Number of rows per batch. n_jobs : int Parallel workers. axis : int Axis along which to split. progress : bool Show progress bar. description : str Progress label. Returns ------- ndarray Concatenated results. """ joblib = _require_joblib() n = data.shape[axis] slices = [slice(i, min(i + batch_size, n)) for i in range(0, n, batch_size)] batches = [np.take(data, range(*s.indices(n)), axis=axis) for s in slices] total = len(batches) if progress: # Stream results and tick in the parent; never pickle the progress # object into worker processes (see parallel_map for the rationale). with progress_parallel(description, total=total) as tick: stream = joblib.Parallel(n_jobs=n_jobs, return_as="generator")( joblib.delayed(func)(b) for b in batches ) results = [] for r in stream: results.append(r) tick(1) else: results = joblib.Parallel(n_jobs=n_jobs)(joblib.delayed(func)(b) for b in batches) return np.concatenate(results, axis=axis)
[docs] def batch_iterator( data: np.ndarray, batch_size: int = 1000, *, axis: int = 0, ) -> Iterator[np.ndarray]: """Iterate over an array in memory-safe batches. Unlike :func:`parallel_batch`, this is a sequential generator suitable for GPU-offloading loops where only one batch should be in memory at a time. Parameters ---------- data : ndarray Array to iterate. batch_size : int Rows per batch. axis : int Axis to split along. Yields ------ ndarray A view (not copy) of the batch. Examples -------- >>> for batch in batch_iterator(big_array, batch_size=500): ... result = expensive_compute(batch) ... accumulate(result) """ n = data.shape[axis] for start in range(0, n, batch_size): end = min(start + batch_size, n) idx = [slice(None)] * data.ndim idx[axis] = slice(start, end) yield data[tuple(idx)]
# ====================================================================== # §4 RAM MEMORY MANAGEMENT # ======================================================================
[docs] @dataclass class MemoryInfo: """Snapshot of system RAM usage. Attributes ---------- total_gb : float Total physical RAM. available_gb : float Available (free + cached) RAM. used_gb : float Actively used RAM. percent_used : float Usage percentage (0–100). """ total_gb: float available_gb: float used_gb: float percent_used: float def __repr__(self) -> str: """Return a human-readable summary.""" return ( f"RAM: {self.used_gb:.1f} / {self.total_gb:.1f} GB " """Return a human-readable RAM status summary.""" f"({self.percent_used:.0f}% used, " f"{self.available_gb:.1f} GB free)" )
[docs] def ram_status() -> MemoryInfo: """Return current system RAM usage. Uses ``/proc/meminfo`` on Linux and ``psutil`` as fallback. Returns ------- MemoryInfo """ # Try /proc/meminfo first (no dependency, Linux only). meminfo_path = Path("/proc/meminfo") if meminfo_path.exists(): info: dict[str, int] = {} with open(meminfo_path) as f: for line in f: parts = line.split() if len(parts) >= 2: key = parts[0].rstrip(":") val_kb = int(parts[1]) info[key] = val_kb total = info.get("MemTotal", 0) / (1024**2) available = info.get("MemAvailable", 0) / (1024**2) used = total - available pct = 100 * used / total if total > 0 else 0 return MemoryInfo(total, available, used, pct) # Fallback: psutil. try: import psutil vm = psutil.virtual_memory() return MemoryInfo( vm.total / (1024**3), vm.available / (1024**3), vm.used / (1024**3), vm.percent, ) except ImportError: logger.warning("Cannot read memory info: /proc/meminfo not found and psutil not installed.") return MemoryInfo(0, 0, 0, 0)
[docs] def gc_collect(generations: int = 2) -> int: """Force Python garbage collection and return bytes freed. Parameters ---------- generations : int GC generations to collect (0, 1, or 2). Returns ------- int Number of unreachable objects collected. Examples -------- >>> del large_array >>> freed = gc_collect() >>> logger.info("GC freed %d objects", freed) """ collected = 0 for gen in range(generations + 1): collected += gc.collect(gen) logger.debug("GC collected %d objects (gen 0–%d)", collected, generations) return collected
[docs] def estimate_array_memory( shape: tuple[int, ...], dtype: np.dtype = np.float64, ) -> float: """Estimate memory for an array in gigabytes. Parameters ---------- shape : tuple of int dtype : numpy dtype Returns ------- float Estimated size in GB. Examples -------- >>> estimate_array_memory((160_000, 300), np.float64) 0.358 # ~358 MB for cortical eigenvectors """ n_elements = 1 for s in shape: n_elements *= s bytes_per_element = np.dtype(dtype).itemsize return n_elements * bytes_per_element / (1024**3)
[docs] @contextmanager def memory_guard( min_available_gb: float = 2.0, error_on_low: bool = False, ) -> Generator[None, None, None]: """Context manager that checks RAM before and after a block. Parameters ---------- min_available_gb : float Minimum free RAM required to proceed. error_on_low : bool Raise ``MemoryError`` if RAM is below threshold. If ``False`` (default), logs a warning instead. Examples -------- >>> with memory_guard(min_available_gb=4.0): ... big_result = compute_all_subjects() """ info = ram_status() if info.available_gb < min_available_gb: msg = ( f"Low RAM: {info.available_gb:.1f} GB available, " f"need {min_available_gb:.1f} GB. " f"Consider closing other applications." ) if error_on_low: raise MemoryError(msg) logger.warning(msg) yield # Post-block: report if memory increased significantly. info_after = ram_status() delta = info.available_gb - info_after.available_gb if delta > 1.0: logger.info( "Block consumed ~%.1f GB RAM (%.1f%.1f GB free)", delta, info.available_gb, info_after.available_gb, )
[docs] def shrink_array( arr: np.ndarray, target_dtype: np.dtype | None = None, ) -> np.ndarray: """Downcast an array to save memory. If *target_dtype* is ``None``, applies safe downcasting rules: float64 → float32, int64 → int32 (if values fit). Parameters ---------- arr : ndarray target_dtype : dtype, optional Returns ------- ndarray A (possibly) smaller copy. """ if target_dtype is not None: return arr.astype(target_dtype, copy=False) if arr.dtype == np.float64: return arr.astype(np.float32) if arr.dtype == np.int64: if arr.min() >= np.iinfo(np.int32).min and arr.max() <= np.iinfo(np.int32).max: return arr.astype(np.int32) return arr
# ====================================================================== # §5 __all__ # ====================================================================== __all__: list[str] = [ # RAM management "MemoryInfo", # Compute backend "NumpyBackend", "NutpieSampler", "PyMCSampler", # Bayesian samplers "SamplerConfig", "batch_iterator", "estimate_array_memory", "gc_collect", "get_bayesian_sampler", "memory_guard", "parallel_batch", # Parallelisation "parallel_map", "ram_status", "shrink_array", ]