Source code for spectralbrain.io.tractseg

"""Import TractSeg outputs for spectral shape analysis.

`TractSeg <https://github.com/MIC-DKFZ/TractSeg>`_ writes one binary
NIfTI mask per white-matter bundle into a ``bundle_segmentations/``
directory (72 bundles in the standard atlas, e.g. ``CST_left.nii.gz``,
``AF_left.nii.gz``). This module turns those masks into the geometric
objects SpectralBrain operates on:

- ``output="pointcloud"`` → a :class:`~spectralbrain.core.pointclouds.BrainPointCloud`
  of the mask's world-space voxel coordinates (ready for point-cloud
  Laplacian spectral analysis).
- ``output="mesh"`` → a :class:`~spectralbrain.core.meshes.BrainMesh`
  isosurface (marching cubes on the binary mask), ready for
  ``.decompose()`` and the mesh descriptors.

Both carry the bundle name and source path in their metadata.

Examples
--------
>>> bundles = load_tractseg("/data/sub-01/tractseg_output", output="mesh")
>>> cst = bundles["CST_left"]
>>> decomp = cst.decompose(k=80)
>>> hks = sb.compute_hks(decomp, t_values=[1, 10, 100])

>>> # A single bundle across a cohort, as point clouds:
>>> files = discover_tractseg_subjects("/data/derivatives/tractseg", "CST_left")
>>> clouds = {sid: load_tractseg_bundle(p) for sid, p in files.items()}
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np

from spectralbrain.runtime import PathLike, get_logger

logger = get_logger(__name__)

#: Subdirectory TractSeg writes bundle masks into.
_DEFAULT_SUBDIR = "bundle_segmentations"


# ======================================================================
# §1  DISCOVERY
# ======================================================================


[docs] def discover_tractseg_bundles( tractseg_dir: PathLike, *, bundles: list[str] | None = None, subdir: str | None = _DEFAULT_SUBDIR, ) -> dict[str, Path]: """Find bundle-segmentation masks in one subject's TractSeg output. Parameters ---------- tractseg_dir : PathLike TractSeg output directory for a subject. bundles : list of str, optional Restrict to these bundle names (without extension, e.g. ``"CST_left"``). Defaults to every ``*.nii.gz`` mask found. subdir : str, optional Subdirectory holding the masks (default ``"bundle_segmentations"``). Pass ``None`` to search *tractseg_dir* directly. Returns ------- dict of {bundle_name: Path} """ root = Path(tractseg_dir) search_dir = root / subdir if subdir else root if not search_dir.is_dir(): raise FileNotFoundError(f"No TractSeg mask directory at {search_dir}") found: dict[str, Path] = {} for p in sorted(search_dir.glob("*.nii*")): name = p.name.split(".")[0] found[name] = p if bundles is not None: missing = [b for b in bundles if b not in found] for b in missing: logger.warning("Bundle %r not found in %s", b, search_dir) found = {b: found[b] for b in bundles if b in found} logger.info("TractSeg: %d bundle masks in %s", len(found), search_dir) return found
[docs] def discover_tractseg_subjects( root: PathLike, bundle: str, *, pattern: str = "sub-{sub}/tractseg_output", subdir: str | None = _DEFAULT_SUBDIR, subjects: list[str] | None = None, ) -> dict[str, Path]: """Find one bundle's mask across subjects in a derivatives tree. Parameters ---------- root : PathLike Derivatives root containing per-subject TractSeg outputs. bundle : str Bundle name (e.g. ``"CST_left"``). pattern : str Per-subject TractSeg directory relative to *root*, with a ``{sub}`` placeholder (bare label). subdir : str, optional Mask subdirectory within each subject's TractSeg output. subjects : list of str, optional Restrict to these subjects. Returns ------- dict of {subject_id: Path} One bundle mask per subject. """ root = Path(root) if subjects is None: labels = sorted(p.name[4:] for p in root.glob("sub-*") if p.is_dir()) else: labels = [s[4:] if s.startswith("sub-") else s for s in subjects] found: dict[str, Path] = {} for label in labels: ts_dir = root / pattern.replace("{sub}", label) search_dir = ts_dir / subdir if subdir else ts_dir matches = sorted(search_dir.glob(f"{bundle}.nii*")) if search_dir.is_dir() else [] if not matches: logger.warning("No %s mask for sub-%s", bundle, label) continue found[f"sub-{label}"] = matches[0] logger.info("TractSeg cohort: %d subjects with bundle %r", len(found), bundle) return found
# ====================================================================== # §2 LOADING # ======================================================================
[docs] def load_tractseg_bundle( mask_path: PathLike, *, output: str = "pointcloud", level: float = 0.5, jitter: bool = False, jitter_scale: float = 0.25, seed: int | None = None, step_size: int = 1, ) -> Any: """Load one TractSeg bundle mask as a point cloud or isosurface mesh. Parameters ---------- mask_path : PathLike A binary (or probabilistic) bundle mask NIfTI. output : ``"pointcloud"`` or ``"mesh"`` ``"pointcloud"`` returns a :class:`BrainPointCloud` of the mask's world-space voxel coordinates; ``"mesh"`` returns a :class:`BrainMesh` isosurface via marching cubes. level : float Threshold separating inside/outside the bundle (default ``0.5``; appropriate for binary masks and TractSeg probability maps). jitter, jitter_scale, seed : Point-cloud only — optional sub-voxel jitter to break the regular grid (helps point-cloud Laplacian estimation). step_size : int Mesh only — marching-cubes step (larger = coarser/faster). Returns ------- BrainPointCloud or BrainMesh With ``metadata["bundle"]`` and ``metadata["source"]`` set. """ from spectralbrain.io.loaders import labels_to_pointcloud, load_nifti path = Path(mask_path) vol, affine = load_nifti(path) binary = (np.asarray(vol) > level).astype(np.int16) if binary.sum() == 0: raise ValueError(f"Empty mask (no voxels above {level}) in {path.name}") bundle = path.name.split(".")[0] meta = {"bundle": bundle, "source": str(path)} if output == "pointcloud": from spectralbrain.core.pointclouds import BrainPointCloud pts = labels_to_pointcloud( binary, affine, label_id=1, jitter=jitter, jitter_scale=jitter_scale, seed=seed, ) return BrainPointCloud(pts, metadata={**meta, "n_voxels": int(binary.sum())}) if output == "mesh": from spectralbrain.core.base import marching_cubes from spectralbrain.core.meshes import BrainMesh verts, faces = marching_cubes( binary.astype(np.float32), affine, level=0.5, step_size=step_size ) return BrainMesh(verts, faces, metadata=meta) raise ValueError(f"Unknown output {output!r}; use 'pointcloud' or 'mesh'.")
[docs] def load_tractseg( tractseg_dir: PathLike, *, bundles: list[str] | None = None, output: str = "pointcloud", subdir: str | None = _DEFAULT_SUBDIR, **kwargs: Any, ) -> dict[str, Any]: """Load all (or selected) bundles from one subject's TractSeg output. Parameters ---------- tractseg_dir : PathLike TractSeg output directory for a subject. bundles : list of str, optional Restrict to these bundle names. Defaults to all masks found. output : ``"pointcloud"`` or ``"mesh"`` Geometric representation per bundle. subdir : str, optional Mask subdirectory (default ``"bundle_segmentations"``). **kwargs Forwarded to :func:`load_tractseg_bundle` (``level``, ``jitter``, ``step_size``, …). Returns ------- dict of {bundle_name: BrainPointCloud or BrainMesh} Bundles that fail to load (e.g. empty masks) are logged and skipped. """ files = discover_tractseg_bundles(tractseg_dir, bundles=bundles, subdir=subdir) out: dict[str, Any] = {} for name, path in files.items(): try: out[name] = load_tractseg_bundle(path, output=output, **kwargs) except Exception as exc: logger.error("✗ %s: %s", name, exc) logger.info("Loaded %d/%d TractSeg bundles as %s.", len(out), len(files), output) return out
__all__ = [ "discover_tractseg_bundles", "discover_tractseg_subjects", "load_tractseg", "load_tractseg_bundle", ]