# Source code for spectralbrain.utils.datasets

"""Template data, example datasets, and public dataset fetchers.

Provides quick access to template geometries (fsaverage, MNI152)
and synthetic example datasets for tutorials and testing.
"""

from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np

from spectralbrain.runtime import Faces, Vertices, get_logger

# Module-level logger obtained from the package runtime, keyed on this module's name.
logger = get_logger(__name__)

# Per-user directory under the home folder (~/.cache/spectralbrain/datasets).
# NOTE(review): presumably where fetched datasets are cached — confirm against
# the code that actually reads or writes this path.
_CACHE_DIR = Path.home() / ".cache" / "spectralbrain" / "datasets"


# ======================================================================
# §1  SYNTHETIC EXAMPLE DATASETS
# ======================================================================


def make_two_group_example(
    n_per_group: int = 30,
    n_vertices: int = 500,
    n_scales: int = 10,
    effect_size: float = 0.5,
    seed: int = 42,
) -> dict[str, Any]:
    """Build a synthetic controls-vs-patients descriptor dataset.

    Both groups are sampled from a standard normal distribution. The
    patients then receive a constant offset, across every scale, at a
    randomly chosen fifth of the vertices (rounded down).

    Parameters
    ----------
    n_per_group : int
        Number of subjects in each of the two groups.
    n_vertices : int
        Number of vertices per subject.
    n_scales : int
        Number of scales per vertex.
    effect_size : float
        Offset added to the patients at the affected vertices. The noise
        has unit variance, so this is the planted Cohen's d.
    seed : int
        Seed of the NumPy random generator.

    Returns
    -------
    dict
        Keys: ``"controls"`` (n, N, T), ``"patients"`` (n, N, T),
        ``"labels"`` (2n,) with 0 for controls and 1 for patients,
        ``"affected_vertices"`` (bool mask of length N), ``"ages"`` (2n,)
        clipped to [18, 85], and ``"metadata"``.
    """
    rng = np.random.default_rng(seed)
    group_shape = (n_per_group, n_vertices, n_scales)

    # Draw order (controls, patients, vertex subset, ages) fixes the output
    # for a given seed, so it must not be rearranged.
    controls = rng.normal(0, 1, group_shape)
    patients = rng.normal(0, 1, group_shape)

    # Focal effect: shift the patients at roughly 20% of the vertices.
    affected_idx = rng.choice(n_vertices, n_vertices // 5, replace=False)
    affected = np.zeros(n_vertices, dtype=bool)
    affected[affected_idx] = True
    patients[:, affected] += effect_size

    labels = np.array([group for group in (0, 1) for _ in range(n_per_group)])
    ages = np.clip(rng.normal(45, 15, 2 * n_per_group), 18, 85)

    metadata = {
        "n_per_group": n_per_group,
        "effect_size": effect_size,
        "n_affected": int(affected.sum()),
    }
    return {
        "controls": controls,
        "patients": patients,
        "labels": labels,
        "affected_vertices": affected,
        "ages": ages,
        "metadata": metadata,
    }
def make_normative_example(
    n_subjects: int = 200,
    n_vertices: int = 500,
    age_range: tuple[float, float] = (20, 80),
    seed: int = 42,
) -> dict[str, Any]:
    """Build a synthetic normative cohort with linear age and sex effects.

    Each descriptor value is unit-variance noise plus ``age * slope`` for
    that vertex plus, for subjects with ``sex == 1``, a small per-vertex
    offset.

    Parameters
    ----------
    n_subjects : int
        Number of subjects.
    n_vertices : int
        Number of vertices per subject.
    age_range : tuple
        ``(low, high)`` bounds of the uniform age distribution.
    seed : int
        Seed of the NumPy random generator.

    Returns
    -------
    dict
        Keys: ``"descriptors"`` (S, N), ``"ages"`` (S,), ``"sex"`` (S,),
        ``"age_slope"`` (N,).
    """
    rng = np.random.default_rng(seed)

    ages = rng.uniform(*age_range, n_subjects)
    sex = rng.binomial(1, 0.5, n_subjects)

    # One slope per vertex, centred on zero so the signs are mixed.
    age_slope = rng.normal(0, 0.01, n_vertices)

    # The noise is drawn before the sex offset; keep that order so results
    # for a given seed do not change.
    noise = rng.normal(0, 1, (n_subjects, n_vertices))
    age_term = ages.reshape(-1, 1) * age_slope.reshape(1, -1)
    sex_offset = rng.normal(0, 0.1, n_vertices)
    sex_term = sex.reshape(-1, 1) * sex_offset.reshape(1, -1)
    descriptors = noise + age_term + sex_term

    return {
        "descriptors": descriptors,
        "ages": ages,
        "sex": sex,
        "age_slope": age_slope,
    }
def make_connectome_example(
    n_subjects: int = 40,
    n_parcels: int = 50,
    n_networks: int = 5,
    group_effect: float = 0.3,
    seed: int = 42,
) -> dict[str, Any]:
    """Generate synthetic geometric connectomes for two groups.

    Parcel ``i`` belongs to network ``i % n_networks``. Every subject
    starts from the same block-structured matrix (0.7 within a network,
    0.3 between networks), receives symmetric Gaussian noise, and has its
    diagonal zeroed. The second group additionally has ``group_effect``
    subtracted from every intra-network edge.

    Parameters
    ----------
    n_subjects : int
        Total number of subjects. The first ``n_subjects // 2`` are
        labelled 0; the remaining ones (one more when ``n_subjects`` is
        odd) are labelled 1 and carry the group effect.
    n_parcels : int
        Number of parcels (matrix size R).
    n_networks : int
        Number of networks the parcels are cycled through.
    group_effect : float
        Amount subtracted from intra-network edges in group 1.
    seed : int
        Seed of the NumPy random generator.

    Returns
    -------
    dict
        Keys: ``"connectomes"`` (S, R, R), ``"labels"`` (S,),
        ``"network_assignments"`` dict mapping parcel index to a
        ``"Net<k>"`` name.
    """
    rng = np.random.default_rng(seed)
    n_controls = n_subjects // 2
    # The effect is applied to every subject from index n_controls on, so
    # the label vector must count them all, including the extra one that
    # appears when n_subjects is odd.
    n_patients = n_subjects - n_controls

    net_assign = {i: f"Net{i % n_networks}" for i in range(n_parcels)}

    # Boolean mask of off-diagonal edges whose endpoints share a network.
    net_ids = np.arange(n_parcels) % n_networks
    intra = net_ids[:, None] == net_ids[None, :]
    np.fill_diagonal(intra, False)

    # Base connectome with block structure and an empty diagonal.
    block = np.where(intra, 0.7, 0.3)
    np.fill_diagonal(block, 0.0)

    connectomes = np.empty((n_subjects, n_parcels, n_parcels))
    for s in range(n_subjects):
        noise = rng.normal(0, 0.1, (n_parcels, n_parcels))
        C = block + (noise + noise.T) / 2  # symmetrised noise keeps C symmetric
        if s >= n_controls:
            C[intra] -= group_effect
        np.fill_diagonal(C, 0)
        connectomes[s] = C

    labels = np.array([0] * n_controls + [1] * n_patients, dtype=int)
    return {
        "connectomes": connectomes,
        "labels": labels,
        "network_assignments": net_assign,
    }
def make_laterality_example(
    n_subjects: int = 40,
    n_features: int = 50,
    asymmetry: float = 0.3,
    seed: int = 42,
) -> dict[str, Any]:
    """Generate synthetic bilateral descriptors for asymmetry analysis.

    Left and right descriptors are independent standard-normal draws. For
    the patient group, ``asymmetry`` is added to the left side at a random
    third of the features (rounded down).

    Parameters
    ----------
    n_subjects : int
        Total number of subjects. The first ``n_subjects // 2`` are
        controls (label 0); the remaining ones (one more when
        ``n_subjects`` is odd) are patients (label 1).
    n_features : int
        Number of features per hemisphere (d).
    asymmetry : float
        Planted L > R asymmetry in patients.
    seed : int
        Seed of the NumPy random generator.

    Returns
    -------
    dict
        Keys: ``"left"`` (S, d), ``"right"`` (S, d), ``"labels"`` (S,),
        ``"affected_features"`` (indices of the features that carry the
        planted asymmetry).
    """
    rng = np.random.default_rng(seed)
    n_controls = n_subjects // 2
    # Rows from n_controls onward are shifted, so all of them are patients;
    # counting them this way keeps labels aligned with the rows when
    # n_subjects is odd.
    n_patients = n_subjects - n_controls

    left = rng.normal(0, 1, (n_subjects, n_features))
    right = rng.normal(0, 1, (n_subjects, n_features))

    # Patients: left > right for some features.
    affected = rng.choice(n_features, n_features // 3, replace=False)
    left[n_controls:, affected] += asymmetry

    labels = np.array([0] * n_controls + [1] * n_patients, dtype=int)
    return {
        "left": left,
        "right": right,
        "labels": labels,
        "affected_features": affected,
    }
# ======================================================================
# §2  TEMPLATE LOADERS
# ======================================================================
def fetch_fsaverage(
    mesh: str = "pial",
    hemisphere: str = "lh",
) -> tuple[Vertices, Faces]:
    """Load an fsaverage template surface.

    Two sources are tried in order:

    1. A FreeSurfer-format file
       ``<nibabel package>/freesurfer/data/fsaverage/surf/<hemisphere>.<mesh>``,
       used only if it exists.
    2. ``nilearn.datasets.fetch_surf_fsaverage``, whose GIFTI file for the
       requested surface is read with nibabel.

    Both attempts are best-effort: a failure in one is logged at debug
    level and the next source is tried.

    Parameters
    ----------
    mesh : str
        ``"pial"``, ``"white"``, ``"inflated"``, ``"sphere"``.
    hemisphere : str
        ``"lh"`` or ``"rh"``.

    Returns
    -------
    vertices, faces
        Vertex coordinates as ``float64`` and face indices as ``int64``.

    Raises
    ------
    ImportError
        If nibabel is not installed.
    FileNotFoundError
        If neither source yields the surface.
    """
    try:
        import nibabel as nib
        from nibabel import freesurfer as fs
    except ImportError as exc:
        raise ImportError("nibabel required for fsaverage.") from exc

    # First choice: a FreeSurfer surface file under the nibabel package.
    try:
        data_dir = Path(nib.__file__).parent / "freesurfer" / "data"
        surf_path = data_dir / "fsaverage" / "surf" / f"{hemisphere}.{mesh}"
        if surf_path.exists():
            v, f = fs.read_geometry(str(surf_path))
            return np.asarray(v, np.float64), np.asarray(f, np.int64)
    except Exception:
        logger.debug("Reading fsaverage from the nibabel directory failed.", exc_info=True)

    # Fallback: try nilearn's fetch_surf_fsaverage.
    try:
        from nilearn.datasets import fetch_surf_fsaverage
    except ImportError:
        logger.debug("nilearn is not installed; skipping the nilearn fallback.")
    else:
        # nilearn names its entries "<mesh>_<side>" with side "left"/"right"
        # and abbreviates the inflated surface to "infl", so the FreeSurfer
        # style arguments have to be translated before the lookup.
        side = {"lh": "left", "rh": "right"}.get(hemisphere, hemisphere)
        surface = {"inflated": "infl"}.get(mesh, mesh)
        key = f"{surface}_{side}"
        try:
            fsavg = fetch_surf_fsaverage(mesh="fsaverage")
            gii = nib.load(fsavg[key])  # load once, read both arrays from it
            v, f = gii.darrays[0].data, gii.darrays[1].data
            return np.asarray(v, np.float64), np.asarray(f, np.int64)
        except Exception:
            logger.debug("Loading fsaverage %r through nilearn failed.", key, exc_info=True)

    raise FileNotFoundError(
        "Could not load fsaverage. Install nibabel or nilearn:\n pip install nibabel nilearn"
    )
# ======================================================================
# §3  EXAMPLE MESH/POINTCLOUD
# ======================================================================
def example_sphere(
    n_lat: int = 30,
    n_lon: int = 60,
    radius: float = 50.0,
) -> tuple[Vertices, Faces]:
    """Return a sphere mesh for quick tests.

    The three arguments are forwarded unchanged, in order, to
    ``SyntheticMesh(seed=0).sphere``.

    Parameters
    ----------
    n_lat : int
    n_lon : int
    radius : float

    Returns
    -------
    vertices, faces
    """
    # Imported here rather than at module level, as in the original.
    from spectralbrain.statistics.surrogates import SyntheticMesh

    generator = SyntheticMesh(seed=0)
    return generator.sphere(n_lat, n_lon, radius)
def example_point_cloud(
    n_points: int = 1000,
    shape: str = "sphere",
    seed: int = 0,
) -> np.ndarray:
    """Return a point cloud of the requested shape for quick tests.

    Parameters
    ----------
    n_points : int
        Number of points, passed to the generator method.
    shape : str
        ``"sphere"``, ``"ellipsoid"``, ``"blob"``, ``"multi_cluster"``.
        Selects the ``SyntheticPointCloud`` method of the same name.
    seed : int
        Seed handed to ``SyntheticPointCloud``.

    Returns
    -------
    ndarray, shape (n_points, 3)

    Raises
    ------
    ValueError
        If ``shape`` is not one of the names listed above.
    """
    from spectralbrain.statistics.surrogates import SyntheticPointCloud

    gen = SyntheticPointCloud(seed=seed)

    # Each supported shape maps to the generator method with the same name;
    # the names are compared in the documented order.
    for known in ("sphere", "ellipsoid", "blob", "multi_cluster"):
        if shape == known:
            return getattr(gen, known)(n_points)

    raise ValueError(f"Unknown shape: {shape!r}")
# Public API of this module, in alphabetical order.
__all__ = [
    "example_point_cloud",
    "example_sphere",
    "fetch_fsaverage",
    "make_connectome_example",
    "make_laterality_example",
    "make_normative_example",
    "make_two_group_example",
]