"""Bayesian statistical models for spectral morphometry.
Six models with a scikit-learn-like API: ``.fit()``, ``.predict()``,
``.score()``, ``.summary()``. All models delegate MCMC sampling
to the backends (PyMC NUTS, nutpie, or NumPyro).
Models
------
1. **HorseshoeRegression** — sparse regression for feature selection.
2. **BayesianGroupComparison** — BEST (Kruschke 2013) with HDI + ROPE.
3. **HierarchicalLinearModel** — multi-site random effects.
4. **GaussianProcessNormative** — GP age-trajectory normative.
5. **BayesianSpatialModel** — GMRF vertex-wise spatial prior.
6. **BayesianConnectome** — hierarchical connectome comparison.
Examples
--------
>>> model = HorseshoeRegression()
>>> model.fit(descriptors, clinical_scores)
>>> model.summary()
>>> predictions = model.predict(new_descriptors)
>>> model.score() # LOO-CV
Dependencies
------------
PyMC, ArviZ (optional, lazy-imported).
"""
from __future__ import annotations
import abc
from pathlib import Path
from typing import Any, Literal
import numpy as np
from spectralbrain.runtime import (
ConnectomeMatrix,
PathLike,
get_logger,
)
logger = get_logger(__name__)
# ======================================================================
# Lazy imports
# ======================================================================
def _require_pymc():
"""Lazy-import PyMC, raising a clear error if unavailable."""
try:
import pymc as pm
return pm
except ImportError as exc:
raise ImportError("PyMC is required for Bayesian models.\n pip install pymc") from exc
def _require_arviz():
"""Lazy-import ArviZ, raising a clear error if unavailable."""
try:
import arviz as az
return az
except ImportError as exc:
raise ImportError(
"ArviZ is required for Bayesian diagnostics.\n pip install arviz"
) from exc
# ======================================================================
# §0 BASE CLASS
# ======================================================================
[docs]
class BayesianModel(abc.ABC):
"""Abstract base for all SpectralBrain Bayesian models.
Subclasses implement :meth:`_build_model` to construct a PyMC
model, and optionally override :meth:`_build_posterior_predictive`
for custom prediction logic.
Attributes
----------
trace_ : arviz.InferenceData or None
Posterior samples (populated after ``.fit()``).
model_ : pymc.Model or None
The PyMC model object.
"""
def __init__(self) -> None:
"""Initialise the base Bayesian model with empty state."""
self.trace_: Any = None
self.model_: Any = None
self._is_fitted: bool = False
@abc.abstractmethod
def _build_model(self, X: np.ndarray, y: np.ndarray, **kwargs: Any) -> Any:
"""Construct and return a PyMC model."""
...
[docs]
def fit(
self,
X: np.ndarray,
y: np.ndarray,
*,
sampler: Literal["auto", "nuts", "nutpie", "numpyro", "blackjax"] = "auto",
draws: int = 2000,
tune: int = 1000,
chains: int = 4,
cores: int = 4,
target_accept: float = 0.95,
seed: int | None = 42,
**kwargs: Any,
) -> BayesianModel:
"""Fit the model via MCMC sampling.
Parameters
----------
X : ndarray, shape (n, d)
Feature matrix.
y : ndarray, shape (n,)
Target variable.
sampler : str
``"auto"`` tries nutpie → numpyro → nuts.
draws, tune, chains, cores : int
MCMC configuration.
target_accept : float
seed : int
**kwargs
Extra arguments passed to the sampler.
Returns
-------
self
"""
pm = _require_pymc()
X = np.asarray(X, dtype=np.float64)
y = np.asarray(y, dtype=np.float64)
# Subclasses pass model-specific data via instance attributes set in
# their own fit(); _build_model reads those, so no kwargs flow here.
self.model_ = self._build_model(X, y)
with self.model_:
if sampler in ("auto", "nuts"):
if sampler == "auto":
# Try nutpie; fall back to PyMC NUTS on a *logged* failure
# (never silently swallow a real model error).
try:
import nutpie
compiled = nutpie.compile_pymc_model(self.model_)
self.trace_ = nutpie.sample(
compiled, draws=draws, tune=tune, chains=chains, seed=seed
)
logger.info("Fitted with nutpie (%d draws × %d chains).", draws, chains)
self._is_fitted = True
return self
except ImportError:
logger.info("nutpie not installed; using PyMC NUTS.")
except Exception as exc:
logger.warning("nutpie failed (%s); falling back to PyMC NUTS.", exc)
# PyMC native NUTS (explicit sampler="nuts" or auto-fallback).
self.trace_ = pm.sample(
draws=draws,
tune=tune,
chains=chains,
cores=cores,
target_accept=target_accept,
random_seed=seed,
return_inferencedata=True,
progressbar=True,
**kwargs,
)
logger.info("Fitted with PyMC NUTS (%d draws × %d chains).", draws, chains)
elif sampler == "nutpie":
import nutpie
compiled = nutpie.compile_pymc_model(self.model_)
self.trace_ = nutpie.sample(
compiled,
draws=draws,
tune=tune,
chains=chains,
seed=seed,
)
logger.info("Fitted with nutpie (%d draws × %d chains).", draws, chains)
elif sampler == "numpyro":
import pymc.sampling.jax as pmjax
self.trace_ = pmjax.sample_numpyro_nuts(
draws=draws,
tune=tune,
chains=chains,
target_accept=target_accept,
random_seed=seed,
progress_bar=True,
**kwargs,
)
logger.info("Fitted with NumPyro (%d draws × %d chains).", draws, chains)
elif sampler == "blackjax":
import pymc.sampling.jax as pmjax
self.trace_ = pmjax.sample_blackjax_nuts(
draws=draws,
tune=tune,
chains=chains,
target_accept=target_accept,
random_seed=seed,
# BlackJAX's progress bar needs the optional `fastprogress`
# package; default off so it works on a base install.
progressbar=kwargs.pop("progressbar", False),
**kwargs,
)
logger.info("Fitted with BlackJAX (%d draws × %d chains).", draws, chains)
else:
raise ValueError(f"Unknown sampler: {sampler!r}")
self._is_fitted = True
return self
[docs]
def predict(
self,
X_new: np.ndarray,
*,
n_samples: int = 500,
seed: int | None = None,
) -> np.ndarray:
"""Generate posterior predictive samples.
Parameters
----------
X_new : ndarray, shape (m, d)
New feature values.
n_samples : int
Number of posterior predictive draws.
seed : int
Returns
-------
ndarray, shape (n_samples, m)
Posterior predictive samples.
"""
self._check_fitted()
pm = _require_pymc()
with self.model_:
pm.set_data({"X": X_new})
ppc = pm.sample_posterior_predictive(
self.trace_,
random_seed=seed,
predictions=True,
)
# Extract prediction array.
pred_vars = list(ppc.predictions.data_vars)
return ppc.predictions[pred_vars[0]].values.reshape(-1, X_new.shape[0])
[docs]
def score(
self,
method: Literal["loo", "waic"] = "loo",
) -> float:
"""Model comparison score via LOO-CV or WAIC.
Parameters
----------
method : str
``"loo"`` — Leave-One-Out via PSIS.
``"waic"`` — Widely Applicable Information Criterion.
Returns
-------
float
Expected log pointwise predictive density (elpd).
"""
self._check_fitted()
az = _require_arviz()
if method == "loo":
result = az.loo(self.trace_, pointwise=False)
return float(result.elpd_loo)
elif method == "waic":
result = az.waic(self.trace_, pointwise=False)
return float(result.elpd_waic)
raise ValueError(f"Unknown method: {method!r}")
[docs]
def summary(
self,
var_names: list[str] | None = None,
hdi_prob: float = 0.94,
) -> Any:
"""ArviZ summary table.
Parameters
----------
var_names : list of str, optional
hdi_prob : float
Returns
-------
pandas.DataFrame
"""
self._check_fitted()
az = _require_arviz()
# ArviZ >= 1.0 renamed ``hdi_prob`` to ``ci_prob``; stay compatible with both.
import inspect
params = inspect.signature(az.summary).parameters
kwargs: dict[str, Any] = {"var_names": var_names}
if "ci_prob" in params:
kwargs["ci_prob"] = hdi_prob
elif "hdi_prob" in params:
kwargs["hdi_prob"] = hdi_prob
return az.summary(self.trace_, **kwargs)
[docs]
def save(self, path: PathLike) -> Path:
"""Save trace to NetCDF."""
self._check_fitted()
_require_arviz()
out = Path(path)
self.trace_.to_netcdf(str(out))
logger.info("Trace saved → %s", out)
return out
[docs]
@classmethod
def load_trace(cls, path: PathLike) -> Any:
"""Load a saved trace."""
az = _require_arviz()
return az.from_netcdf(str(path))
def _check_fitted(self) -> None:
"""Raise RuntimeError if the model has not been fitted."""
if not self._is_fitted:
raise RuntimeError("Model not fitted yet. Call .fit(X, y) first.")
# ======================================================================
# §1 HORSESHOE REGRESSION
# ======================================================================
[docs]
class HorseshoeRegression(BayesianModel):
"""Sparse Bayesian regression with horseshoe prior.
The horseshoe prior (Carvalho, Polson & Scott 2009) provides
aggressive shrinkage of irrelevant coefficients toward zero
while leaving large effects unshrunk — ideal for selecting
which of 20+ spectral descriptors predict a clinical outcome.
Parameters
----------
tau_prior : float
Global shrinkage scale (smaller = more sparse).
Rule of thumb: p_eff / (d - p_eff) / sqrt(n), where
p_eff is expected number of relevant features.
Examples
--------
>>> model = HorseshoeRegression(tau_prior=0.1)
>>> model.fit(descriptors, clinical_scores)
>>> model.summary()
>>> model.predict(new_descriptors)
"""
def __init__(self, tau_prior: float = 0.1) -> None:
"""Initialise with global shrinkage prior scale.
Parameters
----------
tau_prior : float
Global shrinkage scale (smaller = more sparse).
"""
super().__init__()
self.tau_prior = tau_prior
def _build_model(self, X: np.ndarray, y: np.ndarray, **kw: Any) -> Any:
"""Build the horseshoe regression PyMC model."""
pm = _require_pymc()
_n, d = X.shape
with pm.Model() as model:
X_data = pm.Data("X", X)
y_data = pm.Data("y_obs", y)
# Global shrinkage.
tau = pm.HalfCauchy("tau", beta=self.tau_prior)
# Local shrinkage per coefficient.
lam = pm.HalfCauchy("lambda", beta=1.0, shape=d)
# Coefficients.
beta = pm.Normal("beta", mu=0, sigma=tau * lam, shape=d)
# Intercept.
alpha = pm.Normal("alpha", mu=y.mean(), sigma=y.std() * 2)
# Noise.
sigma = pm.HalfNormal("sigma", sigma=y.std())
# Likelihood.
mu = alpha + pm.math.dot(X_data, beta)
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data)
return model
[docs]
def feature_importance(self) -> np.ndarray:
"""Posterior mean of |β| — higher = more important.
Returns
-------
ndarray, shape (d,)
"""
self._check_fitted()
_require_arviz()
beta = self.trace_.posterior["beta"].values
return np.abs(beta).mean(axis=(0, 1))
# ======================================================================
# §2 BAYESIAN GROUP COMPARISON (BEST — Kruschke 2013)
# ======================================================================
[docs]
class BayesianGroupComparison(BayesianModel):
"""Bayesian Estimation Supersedes the T-test (BEST).
Estimates the full posterior distribution of group means,
standard deviations, effect size, and their differences.
Reports HDI (Highest Density Interval) and probability of
the difference exceeding a ROPE.
Parameters
----------
rope : tuple of float
Region Of Practical Equivalence. Default (-0.1, 0.1)
in units of the pooled standard deviation.
Examples
--------
>>> model = BayesianGroupComparison(rope=(-0.1, 0.1))
>>> model.fit(group_a_values, group_b_values)
>>> model.summary()
>>> model.effect_size_posterior()
"""
def __init__(self, rope: tuple[float, float] = (-0.1, 0.1)) -> None:
"""Initialise with ROPE bounds.
Parameters
----------
rope : tuple of float
Region of practical equivalence bounds.
"""
super().__init__()
self.rope = rope
[docs]
def fit(self, group_a: np.ndarray, group_b: np.ndarray, **kwargs) -> BayesianGroupComparison:
"""Fit BEST model.
Parameters
----------
group_a, group_b : ndarray, shape (n_a,) and (n_b,)
"""
a = np.asarray(group_a, dtype=np.float64).ravel()
b = np.asarray(group_b, dtype=np.float64).ravel()
# Store for predict.
self._a = a
self._b = b
# Pack into X, y format for parent .fit().
X = np.zeros((len(a) + len(b), 1))
y = np.concatenate([a, b])
return super().fit(X, y, **kwargs)
def _build_model(self, X: np.ndarray, y: np.ndarray, **kw: Any) -> Any:
"""Build the BEST (Kruschke 2013) PyMC model."""
pm = _require_pymc()
a, b = self._a, self._b
pooled = np.concatenate([a, b])
with pm.Model() as model:
# Group means.
mu_a = pm.Normal("mu_a", mu=pooled.mean(), sigma=pooled.std() * 2)
mu_b = pm.Normal("mu_b", mu=pooled.mean(), sigma=pooled.std() * 2)
# Group standard deviations.
sigma_a = pm.HalfNormal("sigma_a", sigma=pooled.std() * 2)
sigma_b = pm.HalfNormal("sigma_b", sigma=pooled.std() * 2)
# Normality parameter (Student-t df).
nu = pm.Exponential("nu_minus1", lam=1 / 29.0) + 1
# Likelihoods.
pm.StudentT("obs_a", nu=nu, mu=mu_a, sigma=sigma_a, observed=a)
pm.StudentT("obs_b", nu=nu, mu=mu_b, sigma=sigma_b, observed=b)
# Derived quantities.
pm.Deterministic("diff_means", mu_a - mu_b)
pm.Deterministic("diff_stds", sigma_a - sigma_b)
pooled_sd = pm.math.sqrt((sigma_a**2 + sigma_b**2) / 2)
pm.Deterministic("effect_size", (mu_a - mu_b) / pooled_sd)
return model
[docs]
def predict(self, X_new=None, **kwargs):
"""Not applicable for group comparison; use effect_size_posterior()."""
raise NotImplementedError(
"BayesianGroupComparison does not support predict(). "
"Use .effect_size_posterior() or .summary() instead."
)
[docs]
def effect_size_posterior(self) -> np.ndarray:
"""Posterior samples of Cohen's d.
Returns
-------
ndarray, shape (n_samples,)
"""
self._check_fitted()
return self.trace_.posterior["effect_size"].values.ravel()
[docs]
def rope_probability(self) -> dict[str, float]:
"""Probability of effect size in, below, and above ROPE.
Returns
-------
dict
Keys: ``"p_rope"`` (inside), ``"p_below"`` (below),
``"p_above"`` (above ROPE).
"""
d = self.effect_size_posterior()
lo, hi = self.rope
return {
"p_below": float((d < lo).mean()),
"p_rope": float(((d >= lo) & (d <= hi)).mean()),
"p_above": float((d > hi).mean()),
}
# ======================================================================
# §3 HIERARCHICAL LINEAR MODEL
# ======================================================================
[docs]
class HierarchicalLinearModel(BayesianModel):
"""Multi-site hierarchical linear model with random effects.
Models spectral descriptors with fixed effects (group, age, sex)
and random intercepts/slopes per site, handling batch effects
within the model rather than post-hoc harmonisation.
y ~ α + β·X + u_site + ε
Parameters
----------
random_effects : str
``"intercept"`` — random intercept per site.
``"slope"`` — random intercept + slope per site.
Examples
--------
>>> model = HierarchicalLinearModel(random_effects="intercept")
>>> model.fit(X, y, site_labels=sites)
"""
def __init__(
self,
random_effects: Literal["intercept", "slope"] = "intercept",
) -> None:
"""Initialise with random-effects structure.
Parameters
----------
random_effects : ``"intercept"`` or ``"slope"``
Type of random effects per site.
"""
super().__init__()
self.random_effects = random_effects
[docs]
def fit(self, X, y, *, site_labels: np.ndarray, **kwargs):
"""Fit with site labels for random effects."""
self._site_labels = np.asarray(site_labels)
self._unique_sites = np.unique(self._site_labels)
self._site_idx = np.searchsorted(
self._unique_sites,
self._site_labels,
)
return super().fit(X, y, **kwargs)
def _build_model(self, X: np.ndarray, y: np.ndarray, **kw: Any) -> Any:
"""Build the hierarchical linear PyMC model with site random effects."""
pm = _require_pymc()
_n, d = X.shape
n_sites = len(self._unique_sites)
with pm.Model() as model:
X_data = pm.Data("X", X)
y_data = pm.Data("y_obs", y)
site_idx = pm.Data("site_idx", self._site_idx)
# Fixed effects.
alpha = pm.Normal("alpha", mu=y.mean(), sigma=y.std() * 2)
beta = pm.Normal("beta", mu=0, sigma=1, shape=d)
# Random effects by site.
sigma_site = pm.HalfNormal("sigma_site", sigma=y.std())
u_site = pm.Normal("u_site", mu=0, sigma=sigma_site, shape=n_sites)
# Linear predictor.
mu = alpha + pm.math.dot(X_data, beta) + u_site[site_idx]
if self.random_effects == "slope" and d > 0:
sigma_slope = pm.HalfNormal("sigma_slope", sigma=1.0)
beta_site = pm.Normal(
"beta_site",
mu=0,
sigma=sigma_slope,
shape=(n_sites, d),
)
mu = mu + (X_data * beta_site[site_idx]).sum(axis=1)
# Noise.
sigma = pm.HalfNormal("sigma", sigma=y.std())
pm.Normal("y", mu=mu, sigma=sigma, observed=y_data)
return model
[docs]
def site_effects(self) -> np.ndarray:
"""Posterior mean of random intercepts per site.
Returns
-------
ndarray, shape (n_sites,)
"""
self._check_fitted()
return self.trace_.posterior["u_site"].values.mean(axis=(0, 1))
# ======================================================================
# §4 GAUSSIAN PROCESS NORMATIVE
# ======================================================================
[docs]
class GaussianProcessNormative(BayesianModel):
"""GP-based normative model for age trajectories.
Fits a Gaussian Process over age (or any continuous covariate)
to model the normative distribution of a spectral descriptor.
Individual deviations are computed as z-scores from the posterior
predictive.
Parameters
----------
kernel : str
``"matern32"`` or ``"matern52"`` or ``"rbf"``.
lengthscale_prior : float
Prior mean for the GP lengthscale (in years for age).
Examples
--------
>>> gp = GaussianProcessNormative(kernel="matern52")
>>> gp.fit(ages_controls[:, None], descriptor_controls)
>>> z_patient = gp.deviation(age_patient, descriptor_patient)
"""
def __init__(
self,
kernel: Literal["matern32", "matern52", "rbf"] = "matern52",
lengthscale_prior: float = 10.0,
) -> None:
"""Initialise with kernel type and lengthscale prior.
Parameters
----------
kernel : str
GP kernel: ``"matern32"``, ``"matern52"``, or ``"rbf"``.
lengthscale_prior : float
Prior mean for the GP lengthscale.
"""
super().__init__()
self.kernel = kernel
self.lengthscale_prior = lengthscale_prior
self._X_train: np.ndarray | None = None
self._y_train: np.ndarray | None = None
def _build_model(self, X: np.ndarray, y: np.ndarray, **kw: Any) -> Any:
"""Build the GP normative PyMC model."""
pm = _require_pymc()
import pymc.gp as gp
self._X_train = X.copy()
self._y_train = y.copy()
with pm.Model() as model:
# GP hyperpriors.
ls = pm.InverseGamma(
"lengthscale",
alpha=5,
beta=5 * self.lengthscale_prior,
)
eta = pm.HalfNormal("eta", sigma=y.std())
sigma = pm.HalfNormal("sigma", sigma=y.std() * 0.5)
# Kernel.
if self.kernel == "matern32":
cov = eta**2 * gp.cov.Matern32(input_dim=X.shape[1], ls=ls)
elif self.kernel == "matern52":
cov = eta**2 * gp.cov.Matern52(input_dim=X.shape[1], ls=ls)
elif self.kernel == "rbf":
cov = eta**2 * gp.cov.ExpQuad(input_dim=X.shape[1], ls=ls)
else:
raise ValueError(f"Unknown kernel: {self.kernel!r}")
# Marginal GP.
self._gp = gp.Marginal(cov_func=cov)
self._gp.marginal_likelihood("y", X=X, y=y, sigma=sigma)
return model
[docs]
def predict(self, X_new: np.ndarray, **kwargs) -> tuple[np.ndarray, np.ndarray]:
"""Posterior predictive mean and std at new points.
Parameters
----------
X_new : ndarray, shape (m, d)
Returns
-------
mean : ndarray, shape (m,)
std : ndarray, shape (m,)
"""
self._check_fitted()
pm = _require_pymc()
import uuid
pred_name = f"f_pred_{uuid.uuid4().hex[:8]}"
with self.model_:
self._gp.conditional(pred_name, X_new)
ppc = pm.sample_posterior_predictive(
self.trace_,
var_names=[pred_name],
random_seed=kwargs.get("seed"),
)
samples = ppc.posterior_predictive[pred_name].values
samples = samples.reshape(-1, X_new.shape[0])
return samples.mean(axis=0), samples.std(axis=0)
[docs]
def deviation(
self,
age: float,
observed_value: float,
) -> float:
"""Z-score deviation of an individual from the normative.
Parameters
----------
age : float
observed_value : float
Returns
-------
float
Z-score (positive = above normative).
"""
X_new = np.array([[age]])
mean, std = self.predict(X_new)
return float((observed_value - mean[0]) / (std[0] + 1e-30))
# ======================================================================
# §5 BAYESIAN SPATIAL MODEL
# ======================================================================
[docs]
class BayesianSpatialModel(BayesianModel):
"""Vertex-wise Bayesian model with spatial GMRF prior.
Places a Gaussian Markov Random Field prior on the vertex-level
effects, so neighbouring vertices share information. This is
Bayesian spatial smoothing — more principled than Gaussian
kernel pre-smoothing.
Parameters
----------
spatial_strength : float
Precision of the GMRF prior (higher = more spatial
smoothing).
Examples
--------
>>> model = BayesianSpatialModel(spatial_strength=10.0)
>>> model.fit(group_labels, vertex_descriptors,
... adjacency=mesh_adjacency)
"""
def __init__(self, spatial_strength: float = 10.0) -> None:
"""Initialise with GMRF spatial precision.
Parameters
----------
spatial_strength : float
Precision of the spatial prior (higher = more smoothing).
"""
super().__init__()
self.spatial_strength = spatial_strength
[docs]
def fit(self, group_labels, vertex_data, *, adjacency, **kwargs):
"""Fit spatial model.
Parameters
----------
group_labels : ndarray, shape (S,)
Group assignment (0/1) per subject.
vertex_data : ndarray, shape (S, N)
Per-subject vertex-wise descriptor values.
adjacency : sparse matrix, shape (N, N)
Mesh or kNN adjacency.
"""
self._adjacency = adjacency
self._group_labels = np.asarray(group_labels)
self._vertex_data = np.asarray(vertex_data, dtype=np.float64)
# Build as X (group) → y (mean vertex descriptor)
X = group_labels.reshape(-1, 1).astype(np.float64)
y = vertex_data.mean(axis=1) # collapse vertices for base .fit()
return super().fit(X, y, **kwargs)
def _build_model(self, X: np.ndarray, y: np.ndarray, **kw: Any) -> Any:
"""Build the GMRF spatial PyMC model."""
pm = _require_pymc()
import scipy.sparse as sp
N = self._vertex_data.shape[1]
S = len(self._group_labels)
# Build GMRF precision from adjacency (graph Laplacian + diagonal).
adj = sp.csr_matrix(self._adjacency)
degree = np.asarray(adj.sum(axis=1)).ravel()
Q = sp.diags(degree) - adj + sp.eye(N) * self.spatial_strength
with pm.Model() as model:
# Global intercept and group effect.
alpha = pm.Normal("alpha", mu=0, sigma=10)
beta_group = pm.Normal("beta_group", mu=0, sigma=5)
# Vertex-level group effect with spatial prior.
# Simplified: model group difference per vertex as
# spatially smooth via CAR-like prior.
sigma_spatial = pm.HalfNormal("sigma_spatial", sigma=1)
tau_spatial = 1 / (sigma_spatial**2)
# Vertex-level effects (simplified as Normal with spatial std).
vertex_effect = pm.Normal(
"vertex_effect",
mu=0,
sigma=sigma_spatial,
shape=N,
)
# Spatial penalty as potential (soft GMRF).
Q_dense = Q.toarray()
spatial_penalty = (
-0.5 * tau_spatial * pm.math.dot(vertex_effect, pm.math.dot(Q_dense, vertex_effect))
)
pm.Potential("spatial_prior", spatial_penalty)
# Likelihood: per-subject, per-vertex.
sigma_obs = pm.HalfNormal("sigma_obs", sigma=self._vertex_data.std())
group_float = self._group_labels.astype(np.float64)
for s in range(S):
mu_s = alpha + beta_group * group_float[s] + vertex_effect
pm.Normal(
f"y_{s}",
mu=mu_s,
sigma=sigma_obs,
observed=self._vertex_data[s],
)
return model
[docs]
def vertex_effect_map(self) -> np.ndarray:
"""Posterior mean of vertex-level group effect.
Returns
-------
ndarray, shape (N,)
"""
self._check_fitted()
return self.trace_.posterior["vertex_effect"].values.mean(axis=(0, 1))
# ======================================================================
# §6 BAYESIAN CONNECTOME COMPARISON
# ======================================================================
[docs]
class BayesianConnectome(BayesianModel):
"""Hierarchical Bayesian model for connectome comparison.
Models each entry of the geometric connectome matrix with a
hierarchical prior, testing whether connection strengths differ
between groups while sharing information across edges.
Parameters
----------
shrinkage : float
Hierarchical shrinkage strength.
Examples
--------
>>> model = BayesianConnectome()
>>> model.fit(connectomes_patients, connectomes_controls)
>>> diff = model.edge_difference_posterior()
"""
def __init__(self, shrinkage: float = 1.0) -> None:
"""Initialise with hierarchical shrinkage strength.
Parameters
----------
shrinkage : float
Prior shrinkage for edge-level differences.
"""
super().__init__()
self.shrinkage = shrinkage
[docs]
def fit(
self,
group_a_connectomes: np.ndarray,
group_b_connectomes: np.ndarray,
**kwargs,
) -> BayesianConnectome:
"""Fit connectome comparison model.
Parameters
----------
group_a_connectomes : ndarray, shape (n_a, R, R)
group_b_connectomes : ndarray, shape (n_b, R, R)
"""
a = np.asarray(group_a_connectomes, dtype=np.float64)
b = np.asarray(group_b_connectomes, dtype=np.float64)
self._conn_a = a
self._conn_b = b
self._R = a.shape[1]
# Extract upper triangle for modeling.
triu_idx = np.triu_indices(self._R, k=1)
self._triu_idx = triu_idx
len(triu_idx[0])
# Stack into X (group indicator), y (edge values).
a_edges = np.array([c[triu_idx] for c in a]) # (n_a, n_edges)
b_edges = np.array([c[triu_idx] for c in b]) # (n_b, n_edges)
self._a_edges = a_edges
self._b_edges = b_edges
X = np.zeros((len(a) + len(b), 1))
y = np.concatenate([a_edges.mean(axis=1), b_edges.mean(axis=1)])
return super().fit(X, y, **kwargs)
def _build_model(self, X: np.ndarray, y: np.ndarray, **kw: Any) -> Any:
"""Build the hierarchical connectome comparison PyMC model."""
pm = _require_pymc()
n_edges = self._a_edges.shape[1]
a_mean = self._a_edges.mean(axis=0)
b_mean = self._b_edges.mean(axis=0)
all_edges = np.concatenate([self._a_edges, self._b_edges])
with pm.Model() as model:
# Hierarchical prior on edge-level differences.
mu_diff = pm.Normal("mu_diff", mu=0, sigma=self.shrinkage)
sigma_diff = pm.HalfNormal("sigma_diff", sigma=self.shrinkage)
# Per-edge difference.
edge_diff = pm.Normal(
"edge_diff",
mu=mu_diff,
sigma=sigma_diff,
shape=n_edges,
)
# Group means.
grand_mean = pm.Normal(
"grand_mean",
mu=all_edges.mean(),
sigma=all_edges.std(),
shape=n_edges,
)
sigma_obs = pm.HalfNormal("sigma_obs", sigma=all_edges.std())
# Likelihoods.
pm.Normal(
"obs_a",
mu=grand_mean + edge_diff / 2,
sigma=sigma_obs,
observed=a_mean,
)
pm.Normal(
"obs_b",
mu=grand_mean - edge_diff / 2,
sigma=sigma_obs,
observed=b_mean,
)
return model
[docs]
def predict(self, X_new=None, **kwargs):
"""Not applicable for connectome comparison; use edge_difference_posterior()."""
raise NotImplementedError(
"BayesianConnectome does not support predict(). "
"Use .edge_difference_posterior() instead."
)
[docs]
def edge_difference_posterior(self) -> np.ndarray:
"""Posterior mean of per-edge group difference.
Returns
-------
ndarray, shape (n_edges,)
"""
self._check_fitted()
return self.trace_.posterior["edge_diff"].values.mean(axis=(0, 1))
[docs]
def edge_difference_matrix(self) -> ConnectomeMatrix:
"""Reconstruct the difference as a symmetric matrix.
Returns
-------
ndarray, shape (R, R)
"""
diff = self.edge_difference_posterior()
mat = np.zeros((self._R, self._R))
mat[self._triu_idx] = diff
mat += mat.T
return mat
# ======================================================================
__all__: list[str] = [
"BayesianConnectome",
"BayesianGroupComparison",
"BayesianModel",
"BayesianSpatialModel",
"GaussianProcessNormative",
"HierarchicalLinearModel",
"HorseshoeRegression",
]