# Source code for spectralbrain.spectral.distances

"""Spectral distance metrics between shapes and between points.

Two categories of distance:

**Shape-to-shape** — compare two *different* shapes via their
eigenvalue spectra (no point correspondence needed):

- :func:`wesd` — Weighted Spectral Distance (Konukoglu et al. 2013)
- :func:`shapedna_distance` — Euclidean / Mahalanobis / cosine on ShapeDNA

**Point-to-point** — distances *within* a single shape, computed
from the eigenpairs of that shape:

- :func:`biharmonic_distance` — Lipman, Rustamov & Funkhouser 2010
- :func:`commute_time_distance` — random-walk commute time
- :func:`diffusion_distance` — Coifman & Lafon 2006

All point-to-point distances share a unifying form:

.. math::

    d^2(x, y) = \\sum_i g(\\lambda_i)\\,
                (\\varphi_i(x) - \\varphi_i(y))^2

where g(λ) is a spectral filter that defines the metric.
"""

from __future__ import annotations

from typing import Literal

import numpy as np

from spectralbrain.core.base import SpectralDecomposition
from spectralbrain.runtime import (
    DistanceMatrix,
    GlobalDescriptor,
    get_logger,
    progress_simple,
)

# Module-level logger, created through the project's ``get_logger`` helper
# and used by the distance functions below for debug/info/warning output.
logger = get_logger(__name__)


# ======================================================================
# §1  UNIFIED SPECTRAL DISTANCE KERNEL
# ======================================================================


def _spectral_distance_matrix(
    eigenvectors: np.ndarray,
    weights: np.ndarray,
    *,
    indices: np.ndarray | None = None,
) -> DistanceMatrix:
    """Pairwise distances in a weight-scaled eigenvector embedding.

    Every row of *eigenvectors* is mapped to ``row * sqrt(weights)`` and
    ordinary Euclidean distances are taken between the mapped rows, i.e.

        d²(x, y) = Σᵢ wᵢ · (φᵢ(x) − φᵢ(y))²

    Parameters
    ----------
    eigenvectors : ndarray, shape (N, k)
        One row per point, one column per eigenfunction.
    weights : ndarray, shape (k,)
        Spectral filter values g(λᵢ).  Negative entries are treated as 0.
    indices : ndarray of int, optional
        Source rows.  When given, only distances from these rows to all
        N rows are computed.

    Returns
    -------
    ndarray, shape (len(indices), N) if *indices* is given, else (N, N)
    """
    # Clip before the square root so negative weights become zero, not NaN.
    scale = np.sqrt(np.clip(weights, 0.0, None))  # (k,)
    points = eigenvectors * scale[None, :]  # (N, k)
    sq_norms = np.sum(points**2, axis=1)  # (N,)

    # Expansion used below: ||p - q||² = ||p||² + ||q||² - 2·p·q
    if indices is None:
        gram = points @ points.T  # (N, N)
        squared = sq_norms[:, None] + sq_norms[None, :] - 2 * gram
    else:
        picked = points[indices]  # (M, k)
        picked_norms = np.sum(picked**2, axis=1, keepdims=True)  # (M, 1)
        gram = picked @ points.T  # (M, N)
        squared = picked_norms + sq_norms[None, :] - 2 * gram

    # Rounding can leave tiny negative squared distances; clamp them to 0.
    return np.sqrt(np.clip(squared, 0.0, None))


# ======================================================================
# §2  SHAPE-TO-SHAPE DISTANCES
# ======================================================================


def wesd(
    dna_a: GlobalDescriptor,
    dna_b: GlobalDescriptor,
    *,
    p: float = 2.0,
    normalize: bool = True,
) -> float:
    """Weighted Spectral Distance between two ShapeDNA vectors.

    A pseudometric with convergence guarantees — the series converges
    for p > d/2 where d is the dimension of the manifold (d = 2 for
    surfaces, so p > 1 suffices).

    .. math::

        \\text{WESD}(\\Omega_1, \\Omega_2) = \\left(
            \\sum_{i=1}^{k} \\left(
                \\frac{|\\lambda_i^{(1)} - \\lambda_i^{(2)}|}
                      {\\lambda_i^{(1)} \\cdot \\lambda_i^{(2)}}
            \\right)^{p}
        \\right)^{1/p}

    Parameters
    ----------
    dna_a, dna_b : ndarray, shape (d,)
        ShapeDNA eigenvalue sequences (skip λ₀).  If the lengths differ,
        both are truncated to the shorter one.
    p : float
        Exponent.  Must be positive; p > 1 is needed for the convergence
        guarantee on 2D surfaces.
    normalize : bool
        Map to [0, 1) via WESD / (1 + WESD).  An infinite raw distance
        maps to 1.0.

    Returns
    -------
    float
        WESD distance.

    Raises
    ------
    ValueError
        If *p* is not strictly positive.

    References
    ----------
    Konukoglu E, Glocker B, Criminisi A, Pohl KM.  WESD — Weighted
    Spectral Distance for measuring shape dissimilarity.
    *IEEE TPAMI* 35(9):2284–2297, 2013.
    """
    if not p > 0:
        raise ValueError(f"p must be positive, got {p!r}")

    a = np.asarray(dna_a, dtype=np.float64)
    b = np.asarray(dna_b, dtype=np.float64)

    # Truncate to common length.
    n = min(len(a), len(b))
    a, b = a[:n], b[:n]

    # Both must be positive (eigenvalues with λ₀ removed); the floor keeps
    # the denominator away from zero.
    a = np.clip(a, 1e-20, None)
    b = np.clip(b, 1e-20, None)

    terms = np.abs(a - b) / (a * b)

    # p-norm of ``terms``, computed relative to the largest term so that
    # ``terms**p`` cannot overflow to inf for tiny eigenvalues or large p.
    peak = float(terms.max()) if terms.size else 0.0
    if peak == 0.0:
        raw = 0.0
    elif not np.isfinite(peak):
        raw = peak  # inf or NaN input: propagate as-is
    else:
        raw = peak * float(np.sum((terms / peak) ** p)) ** (1.0 / p)

    if normalize:
        if np.isinf(raw):
            # Limit of x / (1 + x); the literal expression would give NaN.
            return 1.0
        return float(raw / (1.0 + raw))
    return float(raw)
def wesd_matrix(
    dna_collection: np.ndarray,
    *,
    p: float = 2.0,
    normalize: bool = True,
) -> DistanceMatrix:
    """Build the symmetric matrix of pairwise WESD values.

    Parameters
    ----------
    dna_collection : ndarray, shape (S, d)
        One ShapeDNA vector per row.
    p : float
        Exponent forwarded to :func:`wesd`.
    normalize : bool
        Normalisation flag forwarded to :func:`wesd`.

    Returns
    -------
    ndarray, shape (S, S)
        Symmetric matrix with a zero diagonal; entry (i, j) is the WESD
        between rows i and j of *dna_collection*.
    """
    n_shapes = dna_collection.shape[0]
    out = np.zeros((n_shapes, n_shapes), dtype=np.float64)
    n_pairs = n_shapes * (n_shapes - 1) // 2

    # Strict upper triangle, row by row; each value is mirrored below.
    upper_rows, upper_cols = np.triu_indices(n_shapes, k=1)

    with progress_simple("WESD matrix", total=n_pairs) as tick:
        for r, c in zip(upper_rows, upper_cols):
            value = wesd(
                dna_collection[r],
                dna_collection[c],
                p=p,
                normalize=normalize,
            )
            out[r, c] = value
            out[c, r] = value
            tick(1)

    return out
def shapedna_distance(
    dna_a: GlobalDescriptor,
    dna_b: GlobalDescriptor,
    *,
    metric: Literal["euclidean", "mahalanobis", "cosine"] = "euclidean",
    cov_inv: np.ndarray | None = None,
) -> float:
    """Simple distance between two ShapeDNA vectors.

    If the two vectors differ in length, both are truncated to the
    shorter one before comparison.

    Parameters
    ----------
    dna_a, dna_b : ndarray, shape (d,)
    metric : str
        ``"euclidean"`` — L2 distance.
        ``"mahalanobis"`` — requires *cov_inv*.
        ``"cosine"`` — 1 − cos(a, b); 1.0 if either vector has
        (near-)zero norm.
    cov_inv : ndarray, shape (d, d), optional
        Inverse covariance matrix for Mahalanobis.  Only its leading
        block matching the common length is used.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If *metric* is unknown, or ``"mahalanobis"`` is requested
        without *cov_inv*.
    """
    a = np.asarray(dna_a, dtype=np.float64)
    b = np.asarray(dna_b, dtype=np.float64)
    n = min(len(a), len(b))
    a, b = a[:n], b[:n]

    if metric == "euclidean":
        return float(np.linalg.norm(a - b))

    if metric == "mahalanobis":
        if cov_inv is None:
            raise ValueError("cov_inv required for Mahalanobis distance.")
        precision = np.asarray(cov_inv, dtype=np.float64)[:n, :n]
        diff = a - b
        quad = diff @ precision @ diff
        # Rounding can leave the quadratic form a hair below zero; clamp
        # so the square root yields 0 instead of NaN.  np.clip keeps NaN
        # inputs as NaN.
        return float(np.sqrt(np.clip(quad, 0.0, None)))

    if metric == "cosine":
        na = np.linalg.norm(a)
        nb = np.linalg.norm(b)
        if na < 1e-20 or nb < 1e-20:
            return 1.0
        # Keep the cosine inside [-1, 1] so the distance stays in [0, 2]
        # despite floating-point rounding.
        cos = np.clip(np.dot(a, b) / (na * nb), -1.0, 1.0)
        return float(1.0 - cos)

    raise ValueError(f"Unknown metric: {metric!r}")
# ======================================================================
# §3  BIHARMONIC DISTANCE  (Lipman, Rustamov & Funkhouser, 2010)
# ======================================================================
def biharmonic_distance(
    decomp: SpectralDecomposition,
    *,
    indices: np.ndarray | None = None,
) -> DistanceMatrix:
    """Biharmonic distance — an intrinsic metric with no free parameters.

    .. math::

        d_B^2(x, y) = \\sum_{i=1}^{k}
            \\frac{(\\varphi_i(x) - \\varphi_i(y))^2}{\\lambda_i^2}

    Eigenvalues at or below 1e-10 are given zero weight, so the
    (near-)constant modes do not contribute.

    Parameters
    ----------
    decomp : SpectralDecomposition
        Supplies ``eigenvalues`` and ``eigenvectors``.
    indices : ndarray of int, optional
        Compute distances only from these source vertices.

    Returns
    -------
    ndarray, shape (M, N) or (N, N)

    References
    ----------
    Lipman Y, Rustamov RM, Funkhouser TA.  Biharmonic distance.
    *ACM TOG* 29(3):27, 2010.
    """
    lam = decomp.eigenvalues
    phi = decomp.eigenvectors

    # Spectral filter g(λ) = 1/λ², applied on the non-trivial spectrum only.
    nontrivial = lam > 1e-10
    spectral_filter = np.zeros_like(lam)
    spectral_filter[nontrivial] = 1.0 / (lam[nontrivial] ** 2)

    distances = _spectral_distance_matrix(phi, spectral_filter, indices=indices)
    logger.debug("Biharmonic distance: shape %s", distances.shape)
    return distances
# ======================================================================
# §4  COMMUTE-TIME DISTANCE
# ======================================================================
def commute_time_distance(
    decomp: SpectralDecomposition,
    *,
    indices: np.ndarray | None = None,
    warn_large: bool = True,
) -> DistanceMatrix:
    """Commute-time distance from random-walk theory.

    .. math::

        d_{CT}^2(x, y) = \\sum_{i=1}^{k}
            \\frac{(\\varphi_i(x) - \\varphi_i(y))^2}{\\lambda_i}

    Eigenvalues at or below 1e-10 are given zero weight.

    .. warning::
        Degenerates on large graphs (N > 50 k) — converges to a
        function of vertex degree only (von Luxburg et al. 2010).
        Use biharmonic distance instead for large meshes.

    Parameters
    ----------
    decomp : SpectralDecomposition
        Supplies ``eigenvalues``, ``eigenvectors`` and ``n_vertices``.
    indices : ndarray of int, optional
        Compute distances only from these source vertices.
    warn_large : bool
        Log a warning when ``decomp.n_vertices`` exceeds 50 000.

    Returns
    -------
    ndarray, shape (M, N) or (N, N)
    """
    if warn_large:
        n_vertices = decomp.n_vertices
        if n_vertices > 50_000:
            logger.warning(
                "Commute-time distance degenerates on large graphs "
                "(N=%d > 50k). Consider biharmonic_distance() instead.",
                n_vertices,
            )

    lam = decomp.eigenvalues
    phi = decomp.eigenvectors

    # Spectral filter g(λ) = 1/λ, applied on the non-trivial spectrum only.
    nontrivial = lam > 1e-10
    spectral_filter = np.zeros_like(lam)
    spectral_filter[nontrivial] = 1.0 / lam[nontrivial]

    distances = _spectral_distance_matrix(phi, spectral_filter, indices=indices)
    logger.debug("Commute-time distance: shape %s", distances.shape)
    return distances
# ======================================================================
# §5  DIFFUSION DISTANCE  (Coifman & Lafon, 2006)
# ======================================================================
def diffusion_distance(
    decomp: SpectralDecomposition,
    t: float,
    *,
    indices: np.ndarray | None = None,
) -> DistanceMatrix:
    """Diffusion distance at a single time scale.

    .. math::

        D_t^2(x, y) = \\sum_{i=1}^{k} e^{-2\\lambda_i t}\\,
                       (\\varphi_i(x) - \\varphi_i(y))^2

    Small *t* ≈ geodesic; large *t* ≈ global diffusion.

    Parameters
    ----------
    decomp : SpectralDecomposition
        Supplies ``eigenvalues`` and ``eigenvectors``.
    t : float
        Diffusion time scale.
    indices : ndarray of int, optional
        Compute distances only from these source vertices.

    Returns
    -------
    ndarray, shape (M, N) or (N, N)

    References
    ----------
    Coifman RR, Lafon S.  Diffusion maps.
    *Applied and Computational Harmonic Analysis* 21(1):5–30, 2006.
    """
    # Spectral filter g(λ) = exp(-2λt), applied to every eigenvalue.
    heat_filter = np.exp(-2.0 * decomp.eigenvalues * t)

    distances = _spectral_distance_matrix(
        decomp.eigenvectors,
        heat_filter,
        indices=indices,
    )
    logger.debug("Diffusion distance (t=%.4f): shape %s", t, distances.shape)
    return distances
def diffusion_distance_multiscale(
    decomp: SpectralDecomposition,
    t_values: np.ndarray,
    *,
    indices: np.ndarray | None = None,
) -> np.ndarray:
    """Stack diffusion-distance matrices for several time scales.

    Calls :func:`diffusion_distance` once per entry of *t_values*, in
    order, and stacks the results along a new leading axis.

    Parameters
    ----------
    decomp : SpectralDecomposition
    t_values : ndarray, shape (T,)
        Time scales.
    indices : ndarray of int, optional
        Forwarded to :func:`diffusion_distance`.

    Returns
    -------
    ndarray, shape (T, M, N) or (T, N, N)
        One distance matrix per time scale.
    """
    per_scale = []
    with progress_simple("Diffusion distance", total=len(t_values)) as tick:
        for time_scale in t_values:
            per_scale.append(
                diffusion_distance(decomp, time_scale, indices=indices)
            )
            tick(1)

    return np.stack(per_scale, axis=0)
# ======================================================================
# §6  SPECTRAL DISTANCE BETWEEN DESCRIPTORS (for geometric connectome)
# ======================================================================
def descriptor_distance(
    desc_a: np.ndarray,
    desc_b: np.ndarray,
    *,
    method: Literal[
        "wasserstein",
        "mmd",
        "euclidean",
        "cosine",
        "correlation",
    ] = "wasserstein",
    **kwargs,
) -> float:
    """Distance between the descriptor distributions of two parcels.

    This is the pairwise building block of the geometric connectome.

    Parameters
    ----------
    desc_a : ndarray, shape (N_a,) or (N_a, T)
        Descriptor values at vertices of parcel A.
    desc_b : ndarray, shape (N_b,) or (N_b, T)
        Descriptor values at vertices of parcel B.
    method : str
        ``"wasserstein"`` — 1D Wasserstein (Earth Mover's Distance);
        for multi-column input the per-column distances are averaged.
        ``"mmd"`` — Maximum Mean Discrepancy with Gaussian kernel.
        ``"euclidean"`` — L2 between distribution means.
        ``"cosine"`` — cosine distance between means (1.0 if either
        mean has near-zero norm).
        ``"correlation"`` — 1 − Pearson r between the mean feature
        vectors (0.0 when there are fewer than two features, 1.0 when
        r is not finite).
    **kwargs
        Forwarded to the MMD implementation only; ignored otherwise.

    Returns
    -------
    float

    Raises
    ------
    ValueError
        If *method* is not one of the names above.
    """

    def _feature_mean(values: np.ndarray) -> np.ndarray:
        # Per-column mean for 2-D input; a length-1 vector for 1-D input.
        if values.ndim > 1:
            return values.mean(axis=0)
        return np.array([values.mean()])

    a = np.asarray(desc_a, dtype=np.float64)
    b = np.asarray(desc_b, dtype=np.float64)

    # Distribution-level methods work on the raw samples.
    if method == "wasserstein":
        return _wasserstein_1d_multi(a, b)
    if method == "mmd":
        return _mmd_gaussian(a, b, **kwargs)

    if method not in ("euclidean", "cosine", "correlation"):
        raise ValueError(f"Unknown method: {method!r}")

    # The remaining methods compare the per-feature means only.
    mean_a = _feature_mean(a)
    mean_b = _feature_mean(b)

    if method == "euclidean":
        return float(np.linalg.norm(mean_a - mean_b))

    if method == "cosine":
        inner = np.dot(mean_a, mean_b)
        norm_a = np.linalg.norm(mean_a)
        norm_b = np.linalg.norm(mean_b)
        if norm_a < 1e-20 or norm_b < 1e-20:
            return 1.0
        return float(1.0 - inner / (norm_a * norm_b))

    # method == "correlation"
    if len(mean_a) < 2:
        return 0.0
    r = np.corrcoef(mean_a, mean_b)[0, 1]
    if np.isfinite(r):
        return float(1.0 - r)
    return 1.0
def _wasserstein_1d_multi(a: np.ndarray, b: np.ndarray) -> float:
    """1D Wasserstein distance, averaged over columns for 2-D input.

    Two 1-D inputs are compared directly.  Otherwise 1-D inputs are
    promoted to a single column and the distance is computed column by
    column over the first ``min(a.shape[1], b.shape[1])`` columns, then
    averaged.
    """
    from scipy.stats import wasserstein_distance

    if a.ndim == 1 and b.ndim == 1:
        return float(wasserstein_distance(a, b))

    if a.ndim == 1:
        a = a[:, None]
    if b.ndim == 1:
        b = b[:, None]
    T = min(a.shape[1], b.shape[1])
    dists = [wasserstein_distance(a[:, t], b[:, t]) for t in range(T)]
    return float(np.mean(dists))


def _mmd_gaussian(
    a: np.ndarray,
    b: np.ndarray,
    *,
    sigma: float | None = None,
) -> float:
    """Maximum Mean Discrepancy with a Gaussian kernel.

    Parameters
    ----------
    a, b : ndarray, shape (N,) or (N, T)
        Samples, one per row; 1-D input is treated as a single column.
    sigma : float, optional
        Kernel bandwidth.  If omitted, the median pairwise distance of
        the pooled samples is used (at most 500 pooled rows, floor 1e-6).

    Returns
    -------
    float
        Square root of the (biased) MMD² estimate, clamped at 0.
    """
    if a.ndim == 1:
        a = a[:, None]
    if b.ndim == 1:
        b = b[:, None]

    if sigma is None:
        # Median heuristic.
        from scipy.spatial.distance import pdist

        pooled = np.vstack([a, b])
        if len(pooled) > 500:
            # Subsample evenly across the pooled rows so that both samples
            # are represented.  Taking only the leading rows would use
            # just ``a`` whenever it has 500 or more rows.
            pick = np.linspace(0, len(pooled) - 1, 500).astype(np.intp)
            pooled = pooled[pick]
        dists = pdist(pooled)
        sigma = float(np.median(dists)) if len(dists) > 0 else 1.0
        sigma = max(sigma, 1e-6)

    gamma = 1.0 / (2.0 * sigma**2)

    def _k(X: np.ndarray, Y: np.ndarray) -> float:
        """Mean Gaussian kernel value over all row pairs of X and Y."""
        D2 = (
            np.sum(X**2, axis=1, keepdims=True)
            + np.sum(Y**2, axis=1, keepdims=True).T
            - 2 * X @ Y.T
        )
        return float(np.mean(np.exp(-gamma * D2)))

    mmd2 = _k(a, a) + _k(b, b) - 2 * _k(a, b)
    # Rounding can push the estimate slightly below zero.
    return float(np.sqrt(max(mmd2, 0.0)))


# ======================================================================
# §7  CONNECTOME BUILDER
# ======================================================================
def build_geometric_connectome(
    parcel_descriptors: dict,
    *,
    method: Literal[
        "wasserstein",
        "mmd",
        "euclidean",
        "cosine",
        "correlation",
    ] = "wasserstein",
    **kwargs,
) -> tuple[DistanceMatrix, list]:
    """Assemble a ROI × ROI distance matrix from per-parcel descriptors.

    Parcel labels are sorted, and :func:`descriptor_distance` is called
    once for every unordered pair of parcels; the result is mirrored so
    the matrix is symmetric with a zero diagonal.

    Parameters
    ----------
    parcel_descriptors : dict of {label: ndarray}
        Mapping from parcel label to descriptor array.
        Each value is shape (N_parcel, T) or (N_parcel,).
    method : str
        Distance method (see :func:`descriptor_distance`).
    **kwargs
        Extra arguments forwarded to :func:`descriptor_distance`.

    Returns
    -------
    matrix : ndarray, shape (R, R)
        Symmetric distance matrix.
    labels : list
        Sorted parcel labels, in the order of the matrix rows/columns.

    Examples
    --------
    >>> parcels = sb.io.apply_parcellation(verts, faces, labels)
    >>> descs = {}
    >>> for lab, (v, f) in parcels.items():
    ...     mesh = BrainMesh(v, f)
    ...     decomp = mesh.decompose(k=30)
    ...     descs[lab] = compute_hks(decomp, n_times=20)
    >>> C, labs = build_geometric_connectome(descs, method="wasserstein")
    """
    labels = sorted(parcel_descriptors)
    n_parcels = len(labels)
    connectome = np.zeros((n_parcels, n_parcels), dtype=np.float64)
    n_pairs = n_parcels * (n_parcels - 1) // 2

    # Strict upper triangle, row by row; each value is mirrored below.
    upper_rows, upper_cols = np.triu_indices(n_parcels, k=1)

    with progress_simple("Geometric connectome", total=n_pairs) as tick:
        for r, c in zip(upper_rows, upper_cols):
            value = descriptor_distance(
                parcel_descriptors[labels[r]],
                parcel_descriptors[labels[c]],
                method=method,
                **kwargs,
            )
            connectome[r, c] = value
            connectome[c, r] = value
            tick(1)

    logger.info(
        "Geometric connectome: %d × %d (method=%s)",
        n_parcels,
        n_parcels,
        method,
    )
    return connectome, labels
def aggregate_to_networks(
    connectome: DistanceMatrix,
    parcel_labels: list,
    network_assignments: dict,
    *,
    aggregation: Literal["mean", "median"] = "mean",
) -> tuple[np.ndarray, list]:
    """Aggregate a parcel-level connectome to network level.

    Entry (i, j) of the result summarises the block of *connectome*
    whose rows belong to network i and whose columns belong to
    network j.  For i == j only the strict upper triangle of the block
    is used, so the zero self-distances on the diagonal are excluded.
    Blocks with no values (e.g. a network with a single parcel on the
    diagonal) are left at 0.

    Parameters
    ----------
    connectome : ndarray, shape (R, R)
        Parcel-level distance matrix.
    parcel_labels : list
        Parcel labels (from :func:`build_geometric_connectome`), in the
        row/column order of *connectome*.
    network_assignments : dict of {parcel_label: network_name}
        Mapping from each parcel to its canonical network.  Parcels
        missing from the mapping are ignored.
    aggregation : str
        ``"mean"`` or ``"median"`` within each block.

    Returns
    -------
    network_matrix : ndarray, shape (K, K)
        Symmetric network-level matrix.
    network_names : list of str
        Sorted network names, in the order of the matrix rows/columns.

    Raises
    ------
    ValueError
        If *aggregation* is neither ``"mean"`` nor ``"median"``.
    """
    # Reject unknown names explicitly instead of silently using the median.
    if aggregation == "mean":
        agg_func = np.mean
    elif aggregation == "median":
        agg_func = np.median
    else:
        raise ValueError(f"Unknown aggregation: {aggregation!r}")

    connectome = np.asarray(connectome)

    networks = sorted(set(network_assignments.values()))
    K = len(networks)
    net_idx = {name: i for i, name in enumerate(networks)}

    # Map parcels to network indices; -1 marks an unassigned parcel.
    parcel_to_net = []
    for lab in parcel_labels:
        net_name = network_assignments.get(lab)
        parcel_to_net.append(-1 if net_name is None else net_idx[net_name])
    parcel_to_net = np.array(parcel_to_net)

    # One membership mask per network, computed once rather than per pair.
    masks = [parcel_to_net == k for k in range(K)]

    network_matrix = np.zeros((K, K), dtype=np.float64)
    for i in range(K):
        for j in range(i, K):
            block = connectome[np.ix_(masks[i], masks[j])]
            if i == j:
                # Intra-network: exclude diagonal.
                vals = block[np.triu_indices_from(block, k=1)]
            else:
                vals = block.ravel()
            if len(vals) > 0:
                network_matrix[i, j] = float(agg_func(vals))
                network_matrix[j, i] = network_matrix[i, j]

    logger.info("Network matrix: %d × %d", K, K)
    return network_matrix, networks
# ====================================================================== __all__: list[str] = [ "aggregate_to_networks", # Point-to-point "biharmonic_distance", # Geometric connectome "build_geometric_connectome", "commute_time_distance", # Descriptor distributions "descriptor_distance", "diffusion_distance", "diffusion_distance_multiscale", "shapedna_distance", # Shape-to-shape "wesd", "wesd_matrix", ]