"""SpectralBrain runtime infrastructure.
Provides cross-cutting services consumed by all other modules:
versioning, canonical type aliases, structured logging, Rich progress
bars for various workloads, and Singularity/Apptainer container
management for optional DL-based preprocessing.
This module has **no** intra-library imports so it can be imported
first without circular dependencies.
Examples
--------
>>> from spectralbrain.runtime import __version__, get_logger
>>> logger = get_logger("spectralbrain.core")
>>> logger.info("Loaded SpectralBrain %s", __version__)
>>> from spectralbrain.runtime import progress_simple
>>> with progress_simple("Computing HKS", total=20) as update:
... for i, t in enumerate(t_values):
... hks[:, i] = _hks_at_t(eigenvalues, eigenvectors, t)
... update(1)
"""
from __future__ import annotations

# ──────────────────────────────────────────────────────────────────────
# Standard library
# ──────────────────────────────────────────────────────────────────────
import hashlib
import logging
import os
import shlex
import shutil
import subprocess
import sys
import threading
import urllib.request
from collections.abc import Callable, Generator
from contextlib import contextmanager
from dataclasses import dataclass
from enum import Enum, auto
from pathlib import Path
from typing import (
    Any,
    TypeVar,
)

# ──────────────────────────────────────────────────────────────────────
# Third-party (hard deps: numpy, scipy; soft dep: rich)
# ──────────────────────────────────────────────────────────────────────
import numpy as np
import numpy.typing as npt
import scipy.sparse as sp

try:
    from rich.console import Console
    from rich.logging import RichHandler
    from rich.progress import (
        BarColumn,
        MofNCompleteColumn,
        Progress,
        SpinnerColumn,
        TaskProgressColumn,
        TextColumn,
        TimeElapsedColumn,
        TimeRemainingColumn,
    )

    _HAS_RICH = True
except ImportError:  # pragma: no cover
    _HAS_RICH = False
# ======================================================================
# §1 VERSIONING
# ======================================================================
__version__: str = "0.1.0-dev"
"""Semantic version string — single source of truth.

Surfaced as ``spectralbrain.__version__`` via the package
``__init__.py``. Build tooling reads this at release time.
"""


def _parse_version_info(version: str) -> tuple[int, int, int]:
    """Return ``(major, minor, patch)`` parsed from a SemVer string.

    Pre-release (``-dev``) and build-metadata (``+...``) suffixes are
    dropped before parsing, e.g. ``"0.1.0-dev"`` → ``(0, 1, 0)``.

    Raises
    ------
    ValueError
        If the core version is not exactly three integer components.
    """
    core = version.split("+", 1)[0].split("-", 1)[0]
    # Unpacking enforces exactly three components; int() rejects non-digits.
    major, minor, patch = (int(part) for part in core.split("."))
    return major, minor, patch


VERSION_INFO: tuple[int, int, int] = _parse_version_info(__version__)
"""(major, minor, patch) as a comparable tuple, derived from ``__version__``."""
# ======================================================================
# §2 CANONICAL TYPE ALIASES
# ======================================================================
#
# Every array flowing through SpectralBrain should be annotated with
# one of these aliases. Static analysers (mypy/pyright) and IDE
# autocompletion use them; runtime code treats them as plain ndarrays.
#
# Naming convention: CamelCase nouns. Array aliases are NDArray-backed;
# the sparse-matrix and image aliases further down are the exceptions
# and say so in their own docstrings.

# ── Geometric primitives ──────────────────────────────────────────────
Vertices = npt.NDArray[np.floating]
"""Vertex coordinates, shape ``(N, 3)``, float64, in mm (RAS)."""

Faces = npt.NDArray[np.intp]
"""Triangle face indices, shape ``(F, 3)``, **0-indexed**.

Typed as ``np.intp``, NumPy's platform index type (int64 on 64-bit
platforms).
"""

Points = npt.NDArray[np.floating]
"""Point-cloud coordinates, shape ``(N, 3)``, float64.

Semantically identical to :data:`Vertices` but used when there is
no face connectivity (volumetric segmentation → voxel centroids).
"""

Normals = npt.NDArray[np.floating]
"""Unit normals, shape ``(N, 3)``, float64 — per-vertex or per-point."""
# ── Spectral primitives ──────────────────────────────────────────────
Eigenvalues = npt.NDArray[np.floating]
"""Ascending LBO eigenvalues, shape ``(k,)``, float64.

All are non-negative: λ₀ ≈ 0 (the constant mode) and λ₁ is the
Fiedler value.
"""

Eigenvectors = npt.NDArray[np.floating]
"""LBO eigenfunctions sampled at each vertex/point, shape ``(N, k)``.

Column *i* holds φᵢ (float64). The columns are orthonormal with
respect to the mass matrix, i.e. Φᵀ M Φ = I.
"""

SparseMatrix = sp.spmatrix
"""Any SciPy sparse matrix (CSR, CSC, COO).

Both the stiffness (Laplacian) matrix *L* and the mass matrix *M* are
stored this way; compute backends convert to their own sparse types.
"""

# Same runtime object as SparseMatrix; the separate name documents intent.
MassMatrix = SparseMatrix
"""Lumped or consistent mass matrix, sparse, shape ``(N, N)``."""
# ── Descriptor outputs ────────────────────────────────────────────────
ScalarMap = npt.NDArray[np.floating]
"""Per-vertex scalar, shape ``(N,)`` — e.g. one HKS time-slice,
Casorati curvature, or a Bayesian surprise score.
"""
DescriptorMatrix = npt.NDArray[np.floating]
"""Multi-scale per-vertex descriptor, shape ``(N, T)`` — e.g. HKS
evaluated at *T* time-scales, or WKS at *T* energies.
"""
GlobalDescriptor = npt.NDArray[np.floating]
"""One vector per shape, shape ``(d,)`` — e.g. ShapeDNA, 3D Zernike
moments, or a Fisher-vector aggregation.
"""
DistanceMatrix = npt.NDArray[np.floating]
"""Pairwise distance/similarity matrix, shape ``(R, R)`` — e.g.
ROI-to-ROI WESD in a geometric connectome.
"""
# ── Neuroimaging types ────────────────────────────────────────────────
LabelArray = npt.NDArray[np.integer]
"""Per-vertex / per-voxel integer label, shape ``(N,)`` — atlas ROI
indices (Schaefer, aseg, etc.).
"""
VolumeImage = Any # nibabel.nifti1.Nifti1Image (lazy import)
"""NIfTI volume — typed as ``Any`` to avoid hard nibabel dep at import."""
SurfaceImage = Any # nibabel.gifti.GiftiImage (lazy import)
"""GIfTI surface — typed as ``Any`` to avoid hard nibabel dep at import."""
# ── Analysis / connectome ─────────────────────────────────────────────
ConnectomeMatrix = npt.NDArray[np.floating]
"""ROI × ROI spectral similarity/distance, shape ``(R, R)``."""
NetworkMatrix = npt.NDArray[np.floating]
"""Network × Network summary, shape ``(K, K)`` — block-averaged
:pydata:`ConnectomeMatrix` (e.g. 7×7 for Yeo networks).
"""
# ── Generic helpers ───────────────────────────────────────────────────
PathLike = str | os.PathLike
"""Anything :class:`pathlib.Path` can consume."""
T = TypeVar("T")
"""Generic type variable used in container utilities."""
# ======================================================================
# §2b ATLASES, DESCRIPTORS, OBJECTIVES, BACKENDS
# ======================================================================
class AtlasScheme(Enum):
    """Brain atlases supported by ``utils.atlas``.

    Member values are the lowercase string identifiers; iteration order
    follows the definition order below.
    """

    # Schaefer family
    SCHAEFER_100 = "schaefer100"
    SCHAEFER_200 = "schaefer200"
    SCHAEFER_400 = "schaefer400"
    SCHAEFER_600 = "schaefer600"
    SCHAEFER_800 = "schaefer800"
    SCHAEFER_1000 = "schaefer1000"
    DKT = "dkt"
    DESTRIEUX = "destrieux"
    ASEG = "aseg"
    THALAMIC_NUCLEI = "thalamic_nuclei"
    AMYGDALA_NUCLEI = "amygdala_nuclei"
    HIPPOCAMPAL_SUBFIELDS = "hippocampal_subfields"
    JULICH_BRAIN = "julich_brain"
    # Tian family
    TIAN_S1 = "tian_s1"
    TIAN_S2 = "tian_s2"
    TIAN_S3 = "tian_s3"
    TIAN_S4 = "tian_s4"
    BRAINNETOME = "brainnetome"
    GLASSER_MMP = "glasser_mmp"
class DescriptorType(Enum):
    """Descriptor identifiers — used by ``recommend_descriptor()``
    and the eligibility registry.

    The string values are the ones listed in ``DESCRIPTOR_ELIGIBILITY``.
    """

    # LBO-based (spectral/descriptors.py)
    SHAPEDNA = "shapedna"
    HKS = "hks"
    SI_HKS = "si_hks"
    WKS = "wks"
    GPS = "gps"
    BATES_SP = "bates_sp"
    BKS = "bks"
    IBKS = "ibks"
    # Distances (spectral/distances.py)
    WESD = "wesd"
    BIHARMONIC = "biharmonic"
    COMMUTE_TIME = "commute_time"
    DIFFUSION = "diffusion"
    # Wavelets (spectral/wavelets.py)
    SGW_MEXICAN_HAT = "sgw_mexican_hat"
    SGW_HEAT = "sgw_heat"
    # Anisotropic (spectral/anisotropic.py)
    FINSLER_HKS = "finsler_hks"
    ASMWD = "asmwd"
    # Collection-aware (spectral/collections.py)
    DWKS = "dwks"
    # Curvature-based
    SHAPE_INDEX = "shape_index"
    CASORATI = "casorati"
    WILLMORE_ENERGY = "willmore_energy"
    # Integral / metric
    INTEGRAL_INVARIANT = "integral_invariant"
    ZERNIKE_3D = "zernike_3d"
    SDF = "sdf"
    AGD = "agd"
    ECCENTRICITY = "eccentricity"
    # Topological
    ECT = "ect"
    PHT = "pht"
    # Information-theoretic
    FRACTAL_DIM = "fractal_dim"
# ── Eligibility registry for recommend_descriptor() ───────────────────
# Keys match AnalysisObjective values; each list holds DescriptorType
# values, in the order they are offered.
DESCRIPTOR_ELIGIBILITY: dict[str, list[str]] = {
    "group_discrimination": [
        "shapedna", "hks", "wks", "si_hks", "bates_sp", "bks", "wesd",
        "sgw_mexican_hat", "casorati", "integral_invariant", "zernike_3d",
        "ect", "fractal_dim",
    ],
    "lateralization": [
        "shapedna", "hks", "wks", "bates_sp", "gps", "bks", "biharmonic",
        "sgw_mexican_hat", "casorati", "integral_invariant", "eccentricity",
        "ect",
    ],
    "longitudinal_change": [
        "shapedna", "hks", "wks", "bates_sp", "gps", "bks", "wesd",
        "diffusion", "dwks", "casorati", "integral_invariant", "ect",
        "fractal_dim",
    ],
    "subregion_detection": [
        "hks", "wks", "gps", "bks", "sgw_mexican_hat", "biharmonic", "dwks",
        "finsler_hks", "shape_index", "casorati", "sdf", "agd",
        "eccentricity",
    ],
}
class AnalysisObjective(Enum):
    """Objectives for ``recommend_descriptor()``.

    Values are the keys of ``DESCRIPTOR_ELIGIBILITY``.
    """

    GROUP_DISCRIMINATION = "group_discrimination"
    LATERALIZATION = "lateralization"
    LONGITUDINAL_CHANGE = "longitudinal_change"
    SUBREGION_DETECTION = "subregion_detection"
class BackendName(Enum):
    """Compute backends, identified by lowercase library name."""

    NUMPY = "numpy"
    JAX = "jax"
    CUPY = "cupy"
    TORCH = "torch"
# ======================================================================
# §3 STRUCTURED LOGGING
# ======================================================================
# Shared stderr console for log records and progress bars (None without Rich).
_CONSOLE: Console | None
if _HAS_RICH:
    _CONSOLE = Console(stderr=True)
else:
    _CONSOLE = None

# Root of the library's logger hierarchy.
_LIB_LOGGER_NAME: str = "spectralbrain"
def get_logger(
    name: str = _LIB_LOGGER_NAME,
    *,
    level: int = logging.INFO,
    rich: bool = True,
) -> logging.Logger:
    """Return a configured logger for a SpectralBrain module.

    On the first call for a given *name* the logger receives exactly one
    handler. A :class:`rich.logging.RichHandler` is used when Rich is
    installed and *rich* is true; otherwise a plain
    :class:`logging.StreamHandler` on stderr is used. Its level is set to
    *level* and propagation to ancestor loggers is disabled.

    If the logger already has handlers it is returned unchanged, and
    *level* / *rich* are ignored.

    Parameters
    ----------
    name : str
        Logger name — submodules should pass ``__name__``.
    level : int
        Logging level applied on first configuration (default
        ``logging.INFO``).
    rich : bool
        Use Rich formatting when available.

    Returns
    -------
    logging.Logger
    """
    logger = logging.getLogger(name)
    if logger.handlers:
        return logger

    logger.setLevel(level)
    # Each configured logger owns its handler, so propagating would print
    # records twice when a parent is configured as well.
    logger.propagate = False

    if rich and _HAS_RICH:
        handler = RichHandler(
            console=_CONSOLE,
            show_path=False,
            show_time=True,
            markup=True,
            rich_tracebacks=True,
            tracebacks_show_locals=False,
        )
        handler.setFormatter(logging.Formatter("%(message)s"))
    else:
        handler = logging.StreamHandler(sys.stderr)
        handler.setFormatter(
            logging.Formatter(
                "[%(asctime)s] %(levelname)-8s %(name)s — %(message)s",
                datefmt="%H:%M:%S",
            )
        )
    logger.addHandler(handler)
    return logger
def set_log_level(level: int | str) -> None:
    """Set the log level for the entire library.

    Applies *level* to the ``spectralbrain`` logger **and** to every
    already-created logger beneath it (``spectralbrain.*``). The children
    must be updated explicitly because :func:`get_logger` gives each
    logger its own level, which takes precedence over the parent's.
    Loggers created after this call are not affected.

    Parameters
    ----------
    level : int or str
        E.g. ``logging.DEBUG`` or ``"WARNING"`` (case-insensitive).

    Raises
    ------
    AttributeError
        If a string does not name a level attribute of :mod:`logging`.
    """
    numeric = level if isinstance(level, int) else getattr(logging, level.upper())
    logging.getLogger(_LIB_LOGGER_NAME).setLevel(numeric)

    prefix = f"{_LIB_LOGGER_NAME}."
    # list(): snapshot, since other threads may register loggers meanwhile.
    for logger_name, candidate in list(logging.root.manager.loggerDict.items()):
        # loggerDict also holds logging.PlaceHolder entries; skip those.
        if logger_name.startswith(prefix) and isinstance(candidate, logging.Logger):
            candidate.setLevel(numeric)
# ======================================================================
# §4 RICH PROGRESS BARS
# ======================================================================
def _fallback_update_factory(
description: str,
total: int | None,
) -> Generator[Callable[[int], None], None, None]:
"""Plain-text fallback when Rich is not installed."""
done = 0
def _update(n: int = 1) -> None:
"""Update the progress state."""
nonlocal done
done += n
if total and done % max(1, total // 10) == 0:
print(f"\r {description}: {100 * done / total:.0f}%", end="", flush=True)
yield _update
print()
@contextmanager
def progress_simple(
    description: str = "Processing",
    total: int | None = None,
) -> Generator[Callable[[int], None], None, None]:
    """Simple progress bar with ETA.

    With Rich installed the bar is transient, i.e. removed once the block
    exits. Without Rich, a plain-text percentage fallback is used.

    Parameters
    ----------
    description : str
        Label shown left of the bar.
    total : int or None
        Step count. ``None`` → indeterminate bar.

    Yields
    ------
    Callable[[int], None]
        ``update(n)`` advances the bar by *n* steps (default 1).

    Examples
    --------
    >>> with progress_simple("Eigensolve", total=n_structures) as tick:
    ...     for s in structures:
    ...         decompose(s)
    ...         tick(1)
    """
    if not _HAS_RICH:
        yield from _fallback_update_factory(description, total)
        return

    cols = [
        SpinnerColumn(),
        TextColumn("[bold blue]{task.description}"),
        BarColumn(bar_width=40),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
    ]
    with Progress(*cols, console=_CONSOLE, transient=True) as prog:
        tid = prog.add_task(description, total=total)

        def _update(n: int = 1) -> None:
            """Advance the Rich task by *n* steps."""
            prog.update(tid, advance=n)

        yield _update
@contextmanager
def progress_parallel(
    description: str = "Parallel jobs",
    total: int | None = None,
) -> Generator[Callable[[int], None], None, None]:
    """Thread-safe progress bar for ``joblib.Parallel`` callbacks.

    Parameters
    ----------
    description : str
        Label shown next to the bar.
    total : int or None
        Total number of jobs; ``None`` gives an indeterminate bar.

    Yields
    ------
    Callable[[int], None]
        ``update(n)``, safe to call from several threads. With Rich it
        forwards to ``Progress.update``. Without Rich, calls into the
        plain-text fallback are serialised with a lock.

    Notes
    -----
    ``update`` drives a bar owned by the current process, so run jobs in
    threads (``prefer="threads"``); process-based backends cannot share it.

    Examples
    --------
    >>> from joblib import Parallel, delayed
    >>> with progress_parallel("Subjects", total=228) as tick:
    ...     def _run(s):
    ...         result = process(s)
    ...         tick(1)
    ...         return result
    ...     Parallel(n_jobs=8, prefer="threads")(delayed(_run)(s) for s in subjects)
    """
    if not _HAS_RICH:
        lock = threading.Lock()
        fallback = _fallback_update_factory(description, total)
        unlocked_update = next(fallback)

        def _locked_update(n: int = 1) -> None:
            """Advance the plain-text fallback while holding the lock."""
            # The fallback's ``nonlocal`` counter update is not atomic.
            with lock:
                unlocked_update(n)

        try:
            yield _locked_update
        finally:
            # Resume the fallback generator so it can finish its output line.
            next(fallback, None)
        return

    cols = [
        SpinnerColumn("dots"),
        TextColumn("[bold magenta]{task.description}"),
        BarColumn(bar_width=40),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
    ]
    with Progress(*cols, console=_CONSOLE, transient=False, refresh_per_second=10) as prog:
        tid = prog.add_task(description, total=total)

        def _update(n: int = 1) -> None:
            """Advance the Rich task by *n* steps."""
            prog.update(tid, advance=n)

        yield _update
class NestedProgress:
    """Two-level Rich progress display for nested loops.

    The outer bar counts *outer_total* steps. The inner bar counts
    *inner_total* steps and is reset whenever the outer bar advances.
    Without Rich the object is a silent no-op, so instrumented loops still
    run unchanged.

    Parameters
    ----------
    outer_description : str
        Outer-loop label (e.g. ``"Subjects"``).
    outer_total : int
        Number of outer iterations.
    inner_description : str
        Inner-loop label (e.g. ``"Structures"``).
    inner_total : int
        Inner iterations **per** outer step.

    Examples
    --------
    >>> with NestedProgress("Subjects", 228, "ROIs", 44) as bar:
    ...     for subj in subjects:
    ...         for roi in rois:
    ...             compute(subj, roi)
    ...             bar.advance_inner()
    ...         bar.advance_outer()
    """

    def __init__(
        self,
        outer_description: str,
        outer_total: int,
        inner_description: str,
        inner_total: int,
    ) -> None:
        """Store labels and totals; the display itself starts in ``__enter__``."""
        self.outer_description = outer_description
        self.outer_total = outer_total
        self.inner_description = inner_description
        self.inner_total = inner_total
        self._progress: Progress | None = None
        self._outer_id: int | None = None
        self._inner_id: int | None = None

    # -- context manager ------------------------------------------------
    def __enter__(self) -> NestedProgress:
        """Start the Rich display (when available) and create both tasks."""
        if not _HAS_RICH:
            return self
        cols = [
            SpinnerColumn(),
            TextColumn("[bold]{task.description}"),
            BarColumn(bar_width=30),
            MofNCompleteColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
        ]
        self._progress = Progress(*cols, console=_CONSOLE, transient=False)
        self._progress.__enter__()
        self._outer_id = self._progress.add_task(
            f"[cyan]{self.outer_description}",
            total=self.outer_total,
        )
        self._inner_id = self._progress.add_task(
            f" [green]{self.inner_description}",
            total=self.inner_total,
        )
        return self

    def __exit__(self, *exc: Any) -> None:
        """Stop the display and detach from it.

        The stored handles are cleared first, so ``advance_*`` calls made
        after the ``with`` block ends are harmless no-ops.
        """
        progress, self._progress = self._progress, None
        self._outer_id = None
        self._inner_id = None
        if progress is not None:
            progress.__exit__(*exc)

    # -- public API -----------------------------------------------------
    def advance_inner(self, n: int = 1) -> None:
        """Advance the inner bar by *n* steps."""
        if self._progress is not None and self._inner_id is not None:
            self._progress.update(self._inner_id, advance=n)

    def advance_outer(self, n: int = 1) -> None:
        """Advance the outer bar by *n* steps and reset the inner bar."""
        if self._progress is None:
            return
        if self._outer_id is not None:
            self._progress.update(self._outer_id, advance=n)
        if self._inner_id is not None:
            self._progress.reset(self._inner_id)
@contextmanager
def progress_spinner(
    description: str = "Working",
) -> Generator[None, None, None]:
    """Indeterminate spinner for unknown-duration operations.

    Without Rich, prints ``description…`` and then `` done.`` to stderr,
    the same stream the Rich console uses. If the block raises, the
    pending line is terminated before the exception propagates.

    Parameters
    ----------
    description : str
        Label next to the spinner.

    Examples
    --------
    >>> with progress_spinner("Downloading container"):
    ...     download_large_file(url, dest)
    """
    if not _HAS_RICH:
        print(f" {description}…", end="", file=sys.stderr, flush=True)
        try:
            yield
        except BaseException:
            # End the "…" line so any traceback starts on a fresh line.
            print(file=sys.stderr, flush=True)
            raise
        print(" done.", file=sys.stderr)
        return

    cols = [
        SpinnerColumn("dots12"),
        TextColumn("[bold yellow]{task.description}"),
        TimeElapsedColumn(),
    ]
    with Progress(*cols, console=_CONSOLE, transient=True) as prog:
        prog.add_task(description, total=None)
        yield
# ======================================================================
# §5 CONTAINER MANAGER (Singularity / Apptainer)
# ======================================================================
_logger = get_logger(f"{_LIB_LOGGER_NAME}.runtime")

# Using ``or`` instead of a ``get()`` default has two effects. An empty
# SPECTRALBRAIN_CACHE counts as unset (rather than meaning the current
# directory), and Path.home() is only consulted when no override exists.
# expanduser() lets users write values such as ``~/sb-cache``.
_DEFAULT_CACHE_DIR: Path = Path(
    os.environ.get("SPECTRALBRAIN_CACHE")
    or Path.home() / ".cache" / "spectralbrain" / "containers"
).expanduser()
@dataclass
class ContainerSpec:
    """Specification for one DL preprocessing container.

    Parameters
    ----------
    name : str
        Human-readable tool name.
    sif_filename : str
        Local cache filename.
    source_url : str
        HTTPS download URL.
    sha256 : str
        Expected SHA-256 hex digest for the integrity check. The literal
        ``"placeholder"`` means no digest is known yet, and
        ``ContainerManager.ensure`` then skips verification.
    size_mb : int
        Approximate download size (shown to user).
    entrypoint : str
        Command template — use ``{input}`` and ``{output}`` placeholders.
    gpu_required : bool
        Needs ``--nv`` GPU passthrough (default ``True``).
    """

    name: str
    sif_filename: str
    source_url: str
    sha256: str
    size_mb: int
    entrypoint: str
    gpu_required: bool = True
# ── Default registry (placeholders until GHCR images are built) ──────
# One row per tool: (key, display name, .sif file, approx MB, command).
# Every image is published under the same release URL prefix.
CONTAINER_REGISTRY: dict[str, ContainerSpec] = {
    key: ContainerSpec(
        name=label,
        sif_filename=sif_name,
        source_url=(
            "https://github.com/rdneuro/spectralbrain-containers"
            f"/releases/download/v0.1.0/{sif_name}"
        ),
        sha256="placeholder",
        size_mb=approx_mb,
        entrypoint=command,
    )
    for key, label, sif_name, approx_mb, command in (
        (
            "hdbet",
            "HD-BET",
            "spectralbrain_hdbet_v1.0.sif",
            2400,
            "hd-bet -i {input} -o {output} -mode fast -tta 0",
        ),
        (
            "synthseg",
            "SynthSeg",
            "spectralbrain_synthseg_v2.0.sif",
            1800,
            "mri_synthseg --i {input} --o {output} --robust",
        ),
        (
            "fastsurfer",
            "FastSurfer",
            "spectralbrain_fastsurfer_v2.3.sif",
            3500,
            "run_fastsurfer.sh --t1 {input} --sd {output} --seg_only",
        ),
    )
}
def _detect_runtime() -> str | None:
"""Find ``apptainer`` or ``singularity`` on PATH."""
for name in ("apptainer", "singularity"):
path = shutil.which(name)
if path is not None:
return path
return None
def _has_nvidia_gpu() -> bool:
"""Return True if ``nvidia-smi`` exits successfully."""
try:
subprocess.run(
["nvidia-smi"],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
check=True,
)
return True
except (FileNotFoundError, subprocess.CalledProcessError):
return False
def _sha256(path: Path, chunk_size: int = 1 << 20) -> str:
"""Compute SHA-256 hex digest of a file."""
h = hashlib.sha256()
with open(path, "rb") as fh:
while chunk := fh.read(chunk_size):
h.update(chunk)
return h.hexdigest()
class ContainerManager:
    """Download, cache, verify, and execute Singularity/Apptainer containers.

    Containers are ``.sif`` files stored under a local cache directory
    (default ``~/.cache/spectralbrain/containers/``; override with the
    ``SPECTRALBRAIN_CACHE`` environment variable). Each is downloaded
    once on first use. When the registry provides a real digest it is
    verified via SHA-256.

    Parameters
    ----------
    cache_dir : PathLike
        Container storage directory (created if missing).
    registry : dict, optional
        Tool name → :class:`ContainerSpec` mapping. ``None`` selects the
        module-level ``CONTAINER_REGISTRY``. An explicitly passed mapping,
        even an empty one, is used as-is.

    Examples
    --------
    >>> cm = ContainerManager()
    >>> cm.status()
    >>> cm.run("hdbet",
    ...        input_path="sub-01_T1w.nii.gz",
    ...        output_path="sub-01_brain.nii.gz")
    """

    def __init__(
        self,
        cache_dir: PathLike = _DEFAULT_CACHE_DIR,
        registry: dict[str, ContainerSpec] | None = None,
    ) -> None:
        """Create the cache directory and detect the container runtime."""
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        # ``registry or CONTAINER_REGISTRY`` would silently replace an
        # intentionally empty mapping with the defaults.
        self.registry = CONTAINER_REGISTRY if registry is None else registry
        self._runtime: str | None = _detect_runtime()

    # ── properties ────────────────────────────────────────────────────
    @property
    def runtime_available(self) -> bool:
        """True if Singularity/Apptainer was found on ``PATH``."""
        return self._runtime is not None

    @property
    def runtime_name(self) -> str:
        """``'apptainer'``, ``'singularity'``, or ``'none'``."""
        return Path(self._runtime).stem if self._runtime else "none"

    # ── internal helpers ──────────────────────────────────────────────
    def _sif_path(self, tool: str) -> Path:
        """Return the cached ``.sif`` path for *tool* (KeyError if unknown)."""
        return self.cache_dir / self.registry[tool].sif_filename

    def _part_path(self, tool: str) -> Path:
        """Return the temporary file used while *tool* is being downloaded."""
        return self._sif_path(tool).with_suffix(".part")

    # ── public API ────────────────────────────────────────────────────
    def is_cached(self, tool: str) -> bool:
        """Check whether the container for *tool* is already downloaded.

        Raises
        ------
        KeyError
            Unknown tool name.
        """
        return self._sif_path(tool).exists()

    def ensure(self, tool: str) -> Path:
        """Download the *tool* container if it is not cached yet.

        The download goes to a ``.part`` file that is removed on any
        failure or interruption. It is only moved into place after the
        checksum passes (or is skipped for a ``"placeholder"`` digest,
        which logs a warning).

        Parameters
        ----------
        tool : str
            Registry key (e.g. ``"hdbet"``).

        Returns
        -------
        Path
            Local ``.sif`` path.

        Raises
        ------
        KeyError
            Unknown tool name.
        RuntimeError
            Download or checksum failure.
        """
        if tool not in self.registry:
            available = ", ".join(self.registry)
            raise KeyError(f"Unknown container '{tool}'. Available: {available}")
        sif = self._sif_path(tool)
        if sif.exists():
            _logger.info("Container [bold]%s[/] already cached.", tool)
            return sif

        spec = self.registry[tool]
        _logger.info(
            "Downloading [bold]%s[/] (~%d MB) — first time only.",
            spec.name,
            spec.size_mb,
        )
        tmp = self._part_path(tool)
        try:
            with progress_spinner(f"Downloading {spec.name}"):
                urllib.request.urlretrieve(spec.source_url, tmp)
        except Exception as exc:
            tmp.unlink(missing_ok=True)
            raise RuntimeError(f"Download failed for {spec.name}: {exc}") from exc
        except BaseException:
            # e.g. KeyboardInterrupt: still don't leave a partial file behind.
            tmp.unlink(missing_ok=True)
            raise

        if spec.sha256 == "placeholder":
            _logger.warning(
                "No SHA-256 digest registered for %s; skipping integrity check.",
                spec.name,
            )
        else:
            digest = _sha256(tmp)
            # hexdigest() is lowercase; tolerate upper-case registry entries.
            if digest != spec.sha256.lower():
                tmp.unlink(missing_ok=True)
                raise RuntimeError(
                    f"SHA-256 mismatch for {spec.name}: "
                    f"expected {spec.sha256[:16]}…, got {digest[:16]}…"
                )
        # replace() overwrites an existing target on every platform, whereas
        # rename() raises on Windows if the destination already exists.
        tmp.replace(sif)
        _logger.info("Cached [bold]%s[/] → %s", spec.name, sif)
        return sif

    def run(
        self,
        tool: str,
        *,
        input_path: PathLike,
        output_path: PathLike,
        extra_binds: list[str] | None = None,
        extra_args: list[str] | None = None,
        gpu: bool | None = None,
    ) -> subprocess.CompletedProcess:
        """Execute a containerised preprocessing tool.

        Parameters
        ----------
        tool : str
            Registry key.
        input_path : PathLike
            Input file (its parent dir is bind-mounted read-only).
        output_path : PathLike
            Output file (its parent dir is created and bind-mounted
            read-write).
        extra_binds : list of str, optional
            Additional ``"host:container"`` bind specs.
        extra_args : list of str, optional
            Arguments appended to the entrypoint.
        gpu : bool or None
            Force GPU on/off; ``None`` = use GPU only if the spec requires
            it and ``nvidia-smi`` succeeds.

        Returns
        -------
        subprocess.CompletedProcess
            With captured ``stdout``/``stderr`` as text.

        Raises
        ------
        OSError
            No container runtime found.
        KeyError, RuntimeError
            Propagated from :meth:`ensure`.
        subprocess.CalledProcessError
            Container exited with non-zero status (its stderr is logged
            first).
        """
        if self._runtime is None:
            raise OSError(
                "No container runtime found. Install Apptainer:\n"
                " https://apptainer.org/docs/admin/latest/installation.html"
            )
        sif = self.ensure(tool)
        spec = self.registry[tool]
        inp = Path(input_path).resolve()
        out = Path(output_path).resolve()
        out.parent.mkdir(parents=True, exist_ok=True)

        cmd: list[str] = [self._runtime, "exec"]
        use_gpu = gpu if gpu is not None else (spec.gpu_required and _has_nvidia_gpu())
        if use_gpu:
            cmd.append("--nv")
        cmd += ["--bind", f"{inp.parent}:/input:ro", "--bind", f"{out.parent}:/output:rw"]
        for bind in extra_binds or []:
            cmd += ["--bind", bind]
        cmd.append(str(sif))
        # Tokenise the *template* before substituting paths, so a file name
        # containing spaces stays one argument; splitting the formatted
        # command on whitespace would break it apart.
        container_paths = {"input": f"/input/{inp.name}", "output": f"/output/{out.name}"}
        cmd += [token.format(**container_paths) for token in shlex.split(spec.entrypoint)]
        cmd += extra_args or []

        _logger.info(
            "Running [bold]%s[/] (%s)",
            spec.name,
            "GPU" if use_gpu else "CPU",
        )
        # Paths may contain "[...]"; keep Rich from parsing them as markup.
        _logger.debug("$ %s", shlex.join(cmd), extra={"markup": False})
        try:
            return subprocess.run(cmd, check=True, capture_output=True, text=True)
        except subprocess.CalledProcessError as exc:
            # Output is captured, so surface the tool's diagnostics before
            # re-raising; otherwise they are only visible via exc.stderr.
            _logger.error(
                "%s exited with status %d:\n%s",
                spec.name,
                exc.returncode,
                (exc.stderr or "").strip(),
                extra={"markup": False},
            )
            raise

    def status(self) -> dict[str, dict[str, Any]]:
        """Print and return the status of every registered container.

        Returns
        -------
        dict
            Tool key → ``{"cached", "path", "size_mb", "gpu"}``; ``path``
            is ``None`` when the image is not cached.
        """
        report: dict[str, dict[str, Any]] = {}
        for name, spec in self.registry.items():
            sif = self._sif_path(name)
            cached = sif.exists()
            report[name] = {
                "cached": cached,
                "path": str(sif) if cached else None,
                "size_mb": spec.size_mb,
                "gpu": spec.gpu_required,
            }

        if _HAS_RICH and _CONSOLE is not None:
            _CONSOLE.rule("[bold]SpectralBrain containers")
            _CONSOLE.print(f"Runtime : [bold]{self.runtime_name}[/]")
            _CONSOLE.print(f"Cache : {self.cache_dir}\n")
            for name, info in report.items():
                ok = "✓" if info["cached"] else "✗"
                c = "green" if info["cached"] else "red"
                _CONSOLE.print(
                    f" [{c}]{ok}[/] [bold]{name:12s}[/] "
                    f"{info['size_mb']:>5d} MB "
                    f"{'GPU' if info['gpu'] else 'CPU'}"
                )
        else:
            # Plain-text rendering so the "print" half of the contract also
            # holds without Rich (stderr, like the Rich console).
            print(
                f"SpectralBrain containers — runtime: {self.runtime_name}, "
                f"cache: {self.cache_dir}",
                file=sys.stderr,
            )
            for name, info in report.items():
                ok = "✓" if info["cached"] else "✗"
                device = "GPU" if info["gpu"] else "CPU"
                print(f"  {ok} {name:12s} {info['size_mb']:>5d} MB  {device}", file=sys.stderr)
        return report

    def clean(self, tool: str | None = None) -> None:
        """Remove cached container(s) and any leftover partial downloads.

        Parameters
        ----------
        tool : str or None
            Specific tool, or ``None`` to clear all. An empty string is
            treated as a (unknown) tool name, not as "all".

        Raises
        ------
        KeyError
            Unknown tool name.
        """
        targets = list(self.registry) if tool is None else [tool]
        for t in targets:
            for path in (self._sif_path(t), self._part_path(t)):
                if path.exists():
                    path.unlink()
                    _logger.info("Removed %s", path)
# ======================================================================
# §6 __all__
# ======================================================================
# Public API, kept in isort-style order (constants, then classes and
# aliases, then functions); see the section headers above for grouping.
# Every entry must be defined in this module, or ``import *`` fails.
__all__: list[str] = [
    "CONTAINER_REGISTRY",
    "DESCRIPTOR_ELIGIBILITY",
    "VERSION_INFO",
    "AnalysisObjective",
    "AtlasScheme",
    "BackendName",
    "ConnectomeMatrix",
    "ContainerManager",
    "ContainerSpec",
    "DescriptorMatrix",
    "DescriptorType",
    "DistanceMatrix",
    "Eigenvalues",
    "Eigenvectors",
    "Faces",
    "GlobalDescriptor",
    "LabelArray",
    "MassMatrix",
    "NestedProgress",
    "NetworkMatrix",
    "Normals",
    "PathLike",
    "Points",
    "ScalarMap",
    "SparseMatrix",
    "SurfaceImage",
    "Vertices",
    "VolumeImage",
    "__version__",
    "get_logger",
    "progress_parallel",
    "progress_simple",
    "progress_spinner",
    "set_log_level",
]