"""Clustering visualisation — 3D mesh renders and 2D statistical plots.
Provides publication-quality figures for every output of
:mod:`spectralbrain.statistics.clustering`: spatial cluster maps,
GNMF components, persistence landscapes, temporal profiles, method
comparisons, Bayesian confirmation diagnostics, and fused-descriptor
panels.
Figure types
------------
**3D mesh renders (vedo)**
1. Cluster map — mesh coloured by integer labels, 3-pose panel
2. Cluster boundaries — semi-transparent surface with inter-cluster edges drawn as lines
3. Multi-method comparison — side-by-side cluster maps
4. GNMF spatial components — one panel per W column
5. Soft membership — mesh coloured by posterior probability
6. Exploded clusters — spatially separated cluster fragments
7. HKS + clusters progression — scalar + labels across t
8. Persistence basins — mesh coloured by persistence-based partition
9. Fusion panel — HKS / WKS / Fused side by side
**2D statistical plots (matplotlib)**
10. Cluster HKS time-profiles — mean ± SEM per cluster
11. Silhouette diagram — per-sample silhouette ordered by cluster
12. Cluster quality comparison — bar chart across methods
13. Method agreement heatmap — ARI / NMI matrix
14. Persistence diagram — birth vs death scatter
15. GNMF temporal factors — F matrix as line profiles
16. Bayesian confirmation — posterior probabilities + credible intervals
17. Cluster size distribution — bar chart
18. UMAP / PCA scatter — embedding coloured by clusters
19. Co-clustering checkerboard — vertex × time block structure
Architecture
------------
* **vedo** for all 3D renders (offscreen VTK → PNG).
* **matplotlib** for all 2D plots (publication style via graphics.py).
* 3D functions return ``(Path, dict)`` — PNG path + metadata.
* 2D functions return ``(Figure, Axes)`` for customisation.
* Every function accepts ``save`` for auto-export.
"""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
from typing import Any, Literal
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from spectralbrain.runtime import PathLike, get_logger
logger = get_logger(__name__)
# ──────────────────────────────────────────────────────────────────────
# Constants
# ──────────────────────────────────────────────────────────────────────
# Default (width, height) render size in pixels.
# NOTE(review): not referenced by the renderers visible in this module;
# multi-panel renders derive their size from the number of panels.
_DEFAULT_SIZE: tuple[int, int] = (1600, 1200)
# Magnification factor forwarded to ``Plotter.screenshot`` by default.
_DEFAULT_SCALE: int = 2
# Default background colour of the 3D render windows.
_DEFAULT_BG: str = "white"
# Resolution (dots per inch) used when saving matplotlib figures.
DPI: int = 600
# Qualitative palette for cluster labels, chosen for colour-blind safety.
# NOTE(review): the first six entries match Paul Tol's "bright" scheme and
# most of the remainder come from his "muted" scheme (#EE8866 is from his
# "light" scheme); confirm the intended source before citing it.
# Callers index it modulo its length, so colours repeat once there are more
# clusters than entries.
CLUSTER_COLORS: list[str] = [
    "#4477AA", # blue
    "#EE6677", # rose
    "#228833", # green
    "#CCBB44", # sand
    "#66CCEE", # cyan
    "#AA3377", # purple
    "#EE8866", # orange
    "#44AA99", # teal
    "#332288", # indigo
    "#CC6677", # wine
    "#882255", # plum
    "#117733", # forest
    "#999933", # olive
    "#DDCC77", # wheat
]
# Default camera presets for multi-view renders (used when ``views=None``).
VIEWS_3POSE: list[str] = ["left_lateral", "anterior", "superior"]
# Camera presets — identical to geometry/meshes.py for consistency.
# Renderers look names up with ``.get(name, {})`` (unknown names → default
# camera) and forward only the "azimuth"/"elevation" keys to ``show``.
CAMERA_PRESETS: dict[str, dict[str, Any]] = {
    "anterior": {"azimuth": 0, "elevation": 0},
    "posterior": {"azimuth": 180, "elevation": 0},
    "left_lateral": {"azimuth": -90, "elevation": 0},
    "right_lateral": {"azimuth": 90, "elevation": 0},
    "superior": {"azimuth": 0, "elevation": 90},
    "inferior": {"azimuth": 0, "elevation": -90},
    "left_medial": {"azimuth": 90, "elevation": 0},
    "right_medial": {"azimuth": -90, "elevation": 0},
    "oblique_left": {"azimuth": -45, "elevation": 30},
    "oblique_right": {"azimuth": 45, "elevation": 30},
}
# ──────────────────────────────────────────────────────────────────────
# Lazy imports & helpers
# ──────────────────────────────────────────────────────────────────────
def _ensure_offscreen() -> None:
"""Set vedo to offscreen rendering mode."""
os.environ.setdefault("VTK_USE_OFFSCREEN", "1")
# Process-wide flag: the Xvfb launch is attempted at most once.
_XVFB_ATTEMPTED: bool = False


def _start_xvfb_once(vedo_module) -> None:
    """Best-effort launch of a virtual X server for headless rendering.

    The launch is attempted at most once per process. It is skipped entirely
    when a ``DISPLAY`` is already configured, so an existing X server is never
    overridden. Failures are logged at DEBUG level and otherwise ignored,
    because offscreen rendering may still work without Xvfb.
    """
    global _XVFB_ATTEMPTED
    if _XVFB_ATTEMPTED:
        return
    _XVFB_ATTEMPTED = True
    if os.environ.get("DISPLAY"):
        return
    try:
        vedo_module.start_xvfb()
    except Exception as exc:  # deliberate best-effort: never fail the render here
        logger.debug("vedo.start_xvfb() failed (%s); continuing without Xvfb.", exc)


def _get_vedo():
    """Lazy-import vedo configured for offscreen rendering.

    Returns
    -------
    module
        The imported ``vedo`` package.

    Raises
    ------
    ImportError
        If vedo is not installed (the underlying import error is chained).
    """
    _ensure_offscreen()
    try:
        import vedo
    except ImportError as err:
        raise ImportError(
            "vedo is required for 3D cluster visualization. Install with: pip install vedo"
        ) from err
    _start_xvfb_once(vedo)
    return vedo
def _build_vedo_mesh(vertices, faces, vedo_module):
"""Construct a vedo Mesh from numpy arrays."""
V = np.asarray(vertices, dtype=np.float64)
F = np.asarray(faces, dtype=np.int64)
cells = np.column_stack([np.full(F.shape[0], 3, dtype=np.int64), F])
mesh = vedo_module.Mesh([V, cells])
return mesh
def _save_screenshot(plotter, save, *, scale=_DEFAULT_SCALE):
    """Capture a vedo Plotter to PNG and close it.

    Parameters
    ----------
    plotter : vedo.Plotter
        Plotter whose renderers have already been populated via ``show``.
    save : path-like or None
        Output PNG path. ``None`` writes to a fresh temporary file.
    scale : int
        Magnification factor forwarded to ``Plotter.screenshot``.

    Returns
    -------
    Path
        Location of the written PNG.

    The plotter is closed even when the capture fails. A temporary file
    created for a failed capture is removed so nothing is left behind.
    """
    created_tmp = save is None
    if created_tmp:
        fd, save = tempfile.mkstemp(suffix=".png")
        os.close(fd)
    save = Path(save)
    try:
        save.parent.mkdir(parents=True, exist_ok=True)
        plotter.screenshot(str(save), scale=scale)
    except Exception:
        if created_tmp:
            save.unlink(missing_ok=True)
        raise
    finally:
        # release the render window / VTK resources on success and failure
        plotter.close()
    logger.info("Saved cluster render → %s", save)
    return save
def _cluster_cmap(n_clusters: int):
    """Return a ``ListedColormap`` holding one palette colour per cluster.

    Colours are taken from ``CLUSTER_COLORS`` in order. When more clusters
    are requested than the palette holds, the palette wraps around.
    """
    palette_size = len(CLUSTER_COLORS)
    if n_clusters > palette_size:
        chosen = [CLUSTER_COLORS[i % palette_size] for i in range(n_clusters)]
    else:
        chosen = CLUSTER_COLORS[:n_clusters]
    return mcolors.ListedColormap(chosen)
def _apply_style():
    """Switch matplotlib to the SpectralBrain publication look.

    The ``science`` / ``no-latex`` styles are used when the optional
    ``scienceplots`` package is importable. The project's own rcParams are
    then overlaid: export DPI, despined axes, frameless legends and 9 pt text.
    This mutates the global ``matplotlib.rcParams``.
    """
    overrides = {
        "savefig.dpi": DPI,
        "figure.dpi": 150,
        "axes.spines.top": False,
        "axes.spines.right": False,
        "legend.frameon": False,
        "font.size": 9,
    }
    try:
        import scienceplots  # noqa: F401

        plt.style.use(["science", "no-latex"])
    except ImportError:
        pass
    plt.rcParams.update(overrides)
def _savefig(fig: Figure, save: PathLike | None) -> None:
    """Write *fig* to disk when a destination is given; otherwise do nothing.

    The figure is saved at ``DPI`` on a white, opaque background. A ``.png``
    destination also gets a vector ``.pdf`` sibling with the same stem.
    Parent directories are created as needed.
    """
    if save is None:
        return
    target = Path(save)
    target.parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(str(target), dpi=DPI, bbox_inches="tight", facecolor="white", transparent=False)
    wants_pdf_twin = target.suffix.lower() == ".png"
    if wants_pdf_twin:
        pdf_target = target.with_suffix(".pdf")
        fig.savefig(str(pdf_target), bbox_inches="tight", facecolor="white")
    logger.info("Saved figure → %s", target)
# ======================================================================
# §1 3D CLUSTER MAP — mesh coloured by labels, 3-pose panel
# ======================================================================
[docs]
def plot_cluster_map(
    vertices: np.ndarray,
    faces: np.ndarray,
    labels: np.ndarray,
    *,
    views: list[str] | None = None,
    noise_color: str = "lightgray",
    lighting: str = "default",
    show_scalarbar: bool = True,
    title: str | None = None,
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Render mesh coloured by cluster labels in a multi-view panel.

    Each cluster gets a distinct colour from the colorblind-safe palette,
    assigned in ascending label order. Noise vertices (label < 0) use
    *noise_color*.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    labels : (V,) int array
        Cluster labels. -1 = noise / unassigned.
    views : list of str or None
        Camera preset names. None → 3-pose (lateral, anterior, superior).
    noise_color : str
        Colour for noise vertices.
    lighting : str
        vedo lighting style.
    show_scalarbar : bool
        Accepted for API compatibility. Categorical labels get no scalar
        bar, so this currently has no effect.
    title : str or None
        If given, used as the title of every panel. Otherwise each panel is
        titled with its view name.
    bg, size, scale, save
        Standard render parameters.

    Returns
    -------
    (Path, dict)
        PNG path and metadata. ``cluster_colors`` maps each (int) label to
        its hex colour. ``views`` is a copy of the views rendered.

    Raises
    ------
    ValueError
        If *labels* does not have one entry per vertex, or *views* is empty.
    """
    vedo = _get_vedo()
    labels = np.asarray(labels, dtype=np.int64).ravel()
    n_vertices = np.asarray(vertices).shape[0]
    if labels.size != n_vertices:
        raise ValueError(f"labels has {labels.size} entries but the mesh has {n_vertices} vertices.")
    # copy so callers mutating meta["views"] cannot alter VIEWS_3POSE
    views = list(VIEWS_3POSE if views is None else views)
    n_views = len(views)
    if n_views == 0:
        raise ValueError("views must contain at least one camera preset name.")
    if size is None:
        size = (600 * n_views, 600)

    # --- per-vertex RGBA (vectorised) ---
    unique_labels = np.unique(labels[labels >= 0])
    n_clusters = int(unique_labels.size)
    rgba = np.tile(mcolors.to_rgba(noise_color), (labels.size, 1))
    if n_clusters:
        palette = _cluster_cmap(n_clusters)(np.arange(n_clusters))
        assigned = labels >= 0
        rgba[assigned] = palette[np.searchsorted(unique_labels, labels[assigned])]
    # vedo expects (V, 4) uint8 vertex colours; round so float error in the
    # [0, 1] → [0, 255] conversion cannot knock a channel down by one step
    rgba_u8 = np.rint(rgba * 255).astype(np.uint8)

    plotter = vedo.Plotter(shape=(1, n_views), offscreen=True, size=size, bg=bg)
    for vi, view_name in enumerate(views):
        mesh = _build_vedo_mesh(vertices, faces, vedo)
        mesh.pointdata["ClusterRGBA"] = rgba_u8
        mesh.pointdata.select("ClusterRGBA")
        mesh.lighting(lighting)
        camera = {
            k: v
            for k, v in CAMERA_PRESETS.get(view_name, {}).items()
            if k in ("azimuth", "elevation")
        }
        plotter.at(vi).show(
            mesh,
            title=title or view_name.replace("_", " ").title(),
            viewup="z",
            zoom=1.1,
            **camera,
        )
    meta = {
        "n_clusters": n_clusters,
        "n_noise": int((labels < 0).sum()),
        "views": views,
        "cluster_colors": {
            int(lab): CLUSTER_COLORS[i % len(CLUSTER_COLORS)]
            for i, lab in enumerate(unique_labels)
        },
    }
    out = _save_screenshot(plotter, save, scale=scale)
    return out, meta
# ======================================================================
# §2 CLUSTER BOUNDARIES — mesh with highlighted boundary edges
# ======================================================================
[docs]
def plot_cluster_boundaries(
    vertices: np.ndarray,
    faces: np.ndarray,
    labels: np.ndarray,
    *,
    mesh_color: str = "ivory",
    mesh_alpha: float = 0.6,
    boundary_width: float = 3.0,
    views: list[str] | None = None,
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
    boundary_color: str = "red",
) -> tuple[Path, dict[str, Any]]:
    """Render mesh with inter-cluster edges drawn as coloured lines.

    A mesh edge is a boundary edge when its two endpoint vertices carry
    different cluster labels and neither endpoint is noise (label < 0).
    These edges are drawn as line segments over a semi-transparent surface.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    labels : (V,) int array
    mesh_color : str
        Base mesh colour.
    mesh_alpha : float
        Base mesh opacity.
    boundary_width : float
        Line width of boundary segments.
    views, lighting, bg, size, scale, save
        Standard render parameters.
    boundary_color : str
        Colour of the boundary segments (default ``"red"``).

    Returns
    -------
    (Path, dict)
        PNG path and ``{"n_boundary_edges", "views"}`` metadata.
    """
    vedo = _get_vedo()
    labels = np.asarray(labels, dtype=np.int64).ravel()
    faces = np.asarray(faces, dtype=np.int64)
    verts = np.asarray(vertices, dtype=np.float64)
    views = list(VIEWS_3POSE if views is None else views)
    n_views = len(views)
    if size is None:
        size = (600 * n_views, 600)

    # --- find boundary edges (vectorised) ---
    if faces.size:
        # each triangle contributes edges (0,1), (1,2), (0,2); sorting each pair
        # makes shared edges identical so np.unique deduplicates them
        edges = np.sort(faces[:, [0, 1, 1, 2, 0, 2]].reshape(-1, 2), axis=1)
        edges = np.unique(edges, axis=0)
        la = labels[edges[:, 0]]
        lb = labels[edges[:, 1]]
        boundary_edges = edges[(la != lb) & (la >= 0) & (lb >= 0)]
    else:
        boundary_edges = np.empty((0, 2), dtype=np.int64)
    n_boundary = int(len(boundary_edges))
    logger.info("Found %d boundary edges between clusters.", n_boundary)

    # --- build line segments ---
    lines = None
    if n_boundary:
        segments = [[verts[a], verts[b]] for a, b in boundary_edges]
        lines = vedo.Lines(segments, c=boundary_color, lw=boundary_width)

    plotter = vedo.Plotter(shape=(1, n_views), offscreen=True, size=size, bg=bg)
    for vi, view_name in enumerate(views):
        mesh = _build_vedo_mesh(verts, faces, vedo)
        mesh.color(mesh_color).alpha(mesh_alpha).lighting(lighting)
        actors = [mesh]
        if lines is not None:
            # each renderer gets its own copy of the line actor
            actors.append(lines.clone())
        camera = {
            k: v
            for k, v in CAMERA_PRESETS.get(view_name, {}).items()
            if k in ("azimuth", "elevation")
        }
        plotter.at(vi).show(
            *actors,
            title=view_name.replace("_", " ").title(),
            viewup="z",
            zoom=1.1,
            **camera,
        )
    meta = {"n_boundary_edges": n_boundary, "views": views}
    out = _save_screenshot(plotter, save, scale=scale)
    return out, meta
# ======================================================================
# §3 MULTI-METHOD COMPARISON — side-by-side cluster maps
# ======================================================================
[docs]
def plot_method_comparison_3d(
    vertices: np.ndarray,
    faces: np.ndarray,
    results: dict[str, np.ndarray],
    *,
    view: str = "left_lateral",
    noise_color: str = "lightgray",
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Compare multiple clustering methods on the same mesh.

    Each method gets one panel, all rendered from the same camera angle.
    Within a panel, clusters are coloured in ascending label order and noise
    (label < 0) uses *noise_color*.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    results : dict[str, ndarray]
        Method name → (V,) label array.
    view : str
        Camera preset for all panels.
    noise_color, lighting, bg, size, scale, save
        Standard parameters.

    Returns
    -------
    (Path, dict)
        PNG path and metadata: ``methods``, ``view`` and ``n_clusters``
        (method → number of non-noise clusters).

    Raises
    ------
    ValueError
        If *results* is empty.
    """
    vedo = _get_vedo()
    methods = list(results)
    n = len(methods)
    if n == 0:
        raise ValueError("results must contain at least one method → labels entry.")
    if size is None:
        size = (600 * n, 600)
    noise_rgba = mcolors.to_rgba(noise_color)
    camera = {
        k: v for k, v in CAMERA_PRESETS.get(view, {}).items() if k in ("azimuth", "elevation")
    }
    plotter = vedo.Plotter(shape=(1, n), offscreen=True, size=size, bg=bg)
    n_clusters_by_method: dict[str, int] = {}
    for i, method in enumerate(methods):
        lab = np.asarray(results[method], dtype=np.int64).ravel()
        unique = np.unique(lab[lab >= 0])
        n_clust = int(unique.size)
        n_clusters_by_method[method] = n_clust
        rgba = np.tile(noise_rgba, (lab.size, 1))
        if n_clust:
            palette = _cluster_cmap(n_clust)(np.arange(n_clust))
            assigned = lab >= 0
            rgba[assigned] = palette[np.searchsorted(unique, lab[assigned])]
        # round (not truncate) to avoid off-by-one channel values
        rgba_u8 = np.rint(rgba * 255).astype(np.uint8)
        mesh = _build_vedo_mesh(vertices, faces, vedo)
        mesh.pointdata["ClusterRGBA"] = rgba_u8
        mesh.pointdata.select("ClusterRGBA")
        mesh.lighting(lighting)
        plotter.at(i).show(
            mesh,
            title=f"{method} (k={n_clust})",
            viewup="z",
            zoom=1.1,
            **camera,
        )
    meta = {"methods": methods, "view": view, "n_clusters": n_clusters_by_method}
    out = _save_screenshot(plotter, save, scale=scale)
    return out, meta
# ======================================================================
# §4 GNMF SPATIAL COMPONENTS — one panel per W column
# ======================================================================
[docs]
def plot_gnmf_components(
    vertices: np.ndarray,
    faces: np.ndarray,
    W: np.ndarray,
    *,
    cmap: str = "inferno",
    max_components: int = 8,
    view: str = "left_lateral",
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
    max_cols: int = 4,
) -> tuple[Path, dict[str, Any]]:
    """Render GNMF spatial factor columns as separate mesh panels.

    Each column of W is shown as a scalar overlay on the mesh. Colour limits
    are set to that column's 1st–99th percentile.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    W : (V, K) array
        Spatial factors from cluster_gnmf. A 1-D array is treated as a
        single component.
    cmap : str
        Colormap name.
    max_components : int
        Show at most this many components.
    view, lighting, bg, size, scale, save
        Standard parameters.
    max_cols : int
        Maximum number of panels per row (default 4).

    Returns
    -------
    (Path, dict)

    Raises
    ------
    ValueError
        If W is not 1-D/2-D, or no component would be shown.
    """
    vedo = _get_vedo()
    W = np.asarray(W, dtype=np.float64)
    if W.ndim == 1:
        W = W[:, None]
    if W.ndim != 2:
        raise ValueError(f"W must be a (V, K) array, got shape {W.shape}")
    K = min(W.shape[1], max_components)
    if K < 1:
        raise ValueError("Nothing to plot: W has no columns or max_components < 1.")
    n_cols = min(K, max(1, max_cols))
    n_rows = (K + n_cols - 1) // n_cols
    if size is None:
        size = (500 * n_cols, 500 * n_rows)
    plotter = vedo.Plotter(shape=(n_rows, n_cols), offscreen=True, size=size, bg=bg)
    camera = {
        kk: v for kk, v in CAMERA_PRESETS.get(view, {}).items() if kk in ("azimuth", "elevation")
    }
    for k in range(K):
        scalars = W[:, k]
        vmin = float(np.nanpercentile(scalars, 1))
        vmax = float(np.nanpercentile(scalars, 99))
        mesh = _build_vedo_mesh(vertices, faces, vedo)
        mesh.pointdata[f"W_{k}"] = scalars
        mesh.cmap(cmap, f"W_{k}", vmin=vmin, vmax=vmax)
        mesh.add_scalarbar(title=f"Component {k}")
        mesh.lighting(lighting)
        # renderers are numbered row-major, so panel k is simply index k
        plotter.at(k).show(
            mesh,
            title=f"W[:, {k}]",
            viewup="z",
            zoom=1.1,
            **camera,
        )
    meta = {"n_components": K, "view": view}
    out = _save_screenshot(plotter, save, scale=scale)
    return out, meta
# ======================================================================
# §5 SOFT MEMBERSHIP — mesh coloured by probability
# ======================================================================
[docs]
def plot_soft_membership(
    vertices: np.ndarray,
    faces: np.ndarray,
    probabilities: np.ndarray,
    cluster_idx: int = 0,
    *,
    cmap: str = "YlOrRd",
    views: list[str] | None = None,
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Show how strongly each vertex belongs to one cluster.

    Column *cluster_idx* of the soft-membership matrix is mapped onto the
    mesh with a fixed [0, 1] colour range and drawn once per camera view.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    probabilities : (V, K) array
        Soft membership matrix.
    cluster_idx : int
        Column (cluster) of *probabilities* to display.
    cmap, views, lighting, bg, size, scale, save
        Standard parameters; ``views=None`` means the 3-pose default.

    Returns
    -------
    (Path, dict)
        PNG path and ``{"cluster_idx": cluster_idx}``.
    """
    vedo = _get_vedo()
    membership = np.asarray(probabilities, dtype=np.float64)[:, cluster_idx]
    view_names = VIEWS_3POSE if views is None else views
    n_views = len(view_names)
    panel_size = size if size is not None else (600 * n_views, 600)
    plotter = vedo.Plotter(shape=(1, n_views), offscreen=True, size=panel_size, bg=bg)
    for panel, view_name in enumerate(view_names):
        mesh = _build_vedo_mesh(vertices, faces, vedo)
        mesh.pointdata["P"] = membership
        mesh.cmap(cmap, "P", vmin=0.0, vmax=1.0)
        mesh.add_scalarbar(title=f"P(cluster={cluster_idx})")
        mesh.lighting(lighting)
        camera = {
            key: val
            for key, val in CAMERA_PRESETS.get(view_name, {}).items()
            if key in ("azimuth", "elevation")
        }
        plotter.at(panel).show(
            mesh,
            title=view_name.replace("_", " ").title(),
            viewup="z",
            zoom=1.1,
            **camera,
        )
    image_path = _save_screenshot(plotter, save, scale=scale)
    return image_path, {"cluster_idx": cluster_idx}
# ======================================================================
# §6 EXPLODED CLUSTERS — spatially separated fragments
# ======================================================================
[docs]
def plot_cluster_exploded(
    vertices: np.ndarray,
    faces: np.ndarray,
    labels: np.ndarray,
    *,
    explosion_factor: float = 1.5,
    view: str = "oblique_left",
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] = (1600, 1200),
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Exploded view — each cluster is displaced outward from the centroid.

    Useful for inspecting cluster topology on convoluted structures like the
    hippocampus, where clusters may overlap visually. Each cluster's
    sub-mesh is made of the faces whose three corners all carry that label.
    The sub-mesh is translated so that its centroid sits at
    ``global_center + explosion_factor * (cluster_center - global_center)``.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    labels : (V,) int array
        Noise vertices (label < 0) are not rendered.
    explosion_factor : float
        Multiplier on each cluster centroid's distance from the global
        centroid. 1.0 = in place, values > 1 push clusters apart.
    view, lighting, bg, size, scale, save
        Standard parameters.

    Returns
    -------
    (Path, dict)
        PNG path and metadata: ``n_clusters`` (non-noise labels found),
        ``explosion_factor`` and ``rendered_labels``. Labels with no fully
        contained face are skipped and absent from ``rendered_labels``.
    """
    vedo = _get_vedo()
    labels = np.asarray(labels, dtype=np.int64).ravel()
    verts = np.asarray(vertices, dtype=np.float64)
    fcs = np.asarray(faces, dtype=np.int64)
    unique_labels = np.unique(labels[labels >= 0])
    n_clusters = int(unique_labels.size)

    global_center = verts.mean(axis=0)
    # old vertex index → sub-mesh index; reused across clusters because each
    # sub-mesh only ever references vertices of its own (freshly mapped) cluster
    remap = np.full(len(verts), -1, dtype=np.int64)
    actors = []
    rendered_labels: list[int] = []
    for ci, lab in enumerate(unique_labels):
        mask = labels == lab
        # keep faces whose three corners all belong to this cluster
        keep = mask[fcs].all(axis=1)
        if not keep.any():
            continue
        vert_idx = np.flatnonzero(mask)
        remap[vert_idx] = np.arange(vert_idx.size)
        sub_faces = remap[fcs[keep]]
        sub_verts = verts[vert_idx]
        offset = sub_verts.mean(axis=0) - global_center
        # scale the centroid's distance from the global centre by explosion_factor
        sub_verts = sub_verts + (explosion_factor - 1.0) * offset
        mesh = _build_vedo_mesh(sub_verts, sub_faces, vedo)
        mesh.color(CLUSTER_COLORS[ci % len(CLUSTER_COLORS)]).lighting(lighting)
        actors.append(mesh)
        rendered_labels.append(int(lab))

    camera = {
        k: v for k, v in CAMERA_PRESETS.get(view, {}).items() if k in ("azimuth", "elevation")
    }
    plotter = vedo.Plotter(offscreen=True, size=size, bg=bg)
    plotter.show(
        *actors,
        title="Exploded Cluster View",
        viewup="z",
        zoom=0.9,
        **camera,
    )
    meta = {
        "n_clusters": n_clusters,
        "explosion_factor": explosion_factor,
        "rendered_labels": rendered_labels,
    }
    out = _save_screenshot(plotter, save, scale=scale)
    return out, meta
# ======================================================================
# §7 HKS + CLUSTERS PROGRESSION — scalar + labels across t
# ======================================================================
[docs]
def plot_hks_cluster_progression(
    vertices: np.ndarray,
    faces: np.ndarray,
    H: np.ndarray,
    labels: np.ndarray,
    t_indices: list[int] | None = None,
    *,
    n_panels: int = 4,
    view: str = "left_lateral",
    cmap_hks: str = "inferno",
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Show HKS scalar at selected time-scales alongside the cluster map.

    Top row: HKS at different t, each clipped to its 1st–99th percentile.
    Bottom row: the cluster labels (noise in light grey), repeated under
    each HKS panel.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    H : (V, T) array
        HKS matrix. A 1-D array is treated as a single time-scale.
    labels : (V,) int array
    t_indices : list of int or None
        Column indices into H to show. None → ``n_panels`` linearly spaced
        indices. Duplicates are dropped when ``n_panels`` exceeds ``T``.
    n_panels : int
        Number of time-scale panels when ``t_indices`` is None.
    view, cmap_hks, lighting, bg, size, scale, save
        Standard parameters.

    Returns
    -------
    (Path, dict)

    Raises
    ------
    ValueError
        If H has no columns or no panel would be drawn.
    IndexError
        If an entry of ``t_indices`` is out of range for H.
    """
    vedo = _get_vedo()
    H = np.asarray(H, dtype=np.float64)
    if H.ndim == 1:
        H = H[:, None]
    labels = np.asarray(labels, dtype=np.int64).ravel()
    T = H.shape[1]
    if T == 0:
        raise ValueError("H must have at least one time-scale column.")
    if t_indices is None:
        spaced = np.linspace(0, T - 1, max(n_panels, 0), dtype=int).tolist()
        # dict.fromkeys removes repeated indices while preserving order
        t_indices = list(dict.fromkeys(spaced))
    else:
        t_indices = [int(t) for t in t_indices]
        out_of_range = [t for t in t_indices if not -T <= t < T]
        if out_of_range:
            raise IndexError(f"t_indices {out_of_range} out of range for H with {T} columns.")
    n_panels = len(t_indices)
    if n_panels == 0:
        raise ValueError("No time-scale panels to draw (empty t_indices or n_panels < 1).")
    if size is None:
        size = (500 * n_panels, 1000)

    # 2 rows: top = HKS, bottom = clusters
    plotter = vedo.Plotter(shape=(2, n_panels), offscreen=True, size=size, bg=bg)
    camera = {
        k: v for k, v in CAMERA_PRESETS.get(view, {}).items() if k in ("azimuth", "elevation")
    }

    # cluster RGBA, computed once and shared by every bottom-row panel
    unique = np.unique(labels[labels >= 0])
    n_clusters = int(unique.size)
    rgba = np.tile(mcolors.to_rgba("lightgray"), (labels.size, 1))
    if n_clusters:
        palette = _cluster_cmap(n_clusters)(np.arange(n_clusters))
        assigned = labels >= 0
        rgba[assigned] = palette[np.searchsorted(unique, labels[assigned])]
    rgba_u8 = np.rint(rgba * 255).astype(np.uint8)

    for pi, ti in enumerate(t_indices):
        # top row: HKS
        mesh_hks = _build_vedo_mesh(vertices, faces, vedo)
        sc = H[:, ti]
        v0 = float(np.nanpercentile(sc, 1))
        v1 = float(np.nanpercentile(sc, 99))
        mesh_hks.pointdata["HKS"] = sc
        mesh_hks.cmap(cmap_hks, "HKS", vmin=v0, vmax=v1)
        mesh_hks.add_scalarbar(title=f"t={ti}")
        mesh_hks.lighting(lighting)
        plotter.at(pi).show(
            mesh_hks,
            title=f"HKS t[{ti}]",
            viewup="z",
            zoom=1.1,
            **camera,
        )
        # bottom row: clusters
        mesh_cl = _build_vedo_mesh(vertices, faces, vedo)
        mesh_cl.pointdata["ClusterRGBA"] = rgba_u8
        mesh_cl.pointdata.select("ClusterRGBA")
        mesh_cl.lighting(lighting)
        plotter.at(n_panels + pi).show(
            mesh_cl,
            title="Clusters",
            viewup="z",
            zoom=1.1,
            **camera,
        )
    meta = {"t_indices": t_indices, "n_panels": n_panels}
    out = _save_screenshot(plotter, save, scale=scale)
    return out, meta
# ======================================================================
# §8 FUSION PANEL — HKS / WKS / Fused side by side
# ======================================================================
[docs]
def plot_fusion_panel(
    vertices: np.ndarray,
    faces: np.ndarray,
    hks_scalar: np.ndarray,
    wks_scalar: np.ndarray,
    fused_scalar: np.ndarray,
    *,
    cmap_hks: str = "inferno",
    cmap_wks: str = "cividis",
    cmap_fused: str = "magma",
    view: str = "left_lateral",
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] = (1800, 600),
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Render HKS, WKS and the fused descriptor as three adjacent panels.

    Every panel shares one camera preset. Each uses its own colormap, with
    limits at the 1st–99th percentile of that descriptor.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    hks_scalar, wks_scalar, fused_scalar : (V,) arrays
        Single-scale scalars for each descriptor type.
    cmap_hks, cmap_wks, cmap_fused : str
        Colormap per panel.
    view, lighting, bg, size, scale, save
        Standard render parameters.

    Returns
    -------
    (Path, dict)
        PNG path and ``{"view": view}``.
    """
    vedo = _get_vedo()
    preset = CAMERA_PRESETS.get(view, {})
    camera_kwargs = {key: val for key, val in preset.items() if key in ("azimuth", "elevation")}
    plotter = vedo.Plotter(shape=(1, 3), offscreen=True, size=size, bg=bg)
    descriptors = [
        ("HKS", hks_scalar, cmap_hks),
        ("WKS", wks_scalar, cmap_wks),
        ("Fused", fused_scalar, cmap_fused),
    ]
    for slot, (label, raw_values, colormap) in enumerate(descriptors):
        values = np.asarray(raw_values, dtype=np.float64)
        low = float(np.nanpercentile(values, 1))
        high = float(np.nanpercentile(values, 99))
        mesh = _build_vedo_mesh(vertices, faces, vedo)
        mesh.pointdata[label] = values
        mesh.cmap(colormap, label, vmin=low, vmax=high)
        mesh.add_scalarbar(title=label)
        mesh.lighting(lighting)
        plotter.at(slot).show(mesh, title=label, viewup="z", zoom=1.1, **camera_kwargs)
    image_path = _save_screenshot(plotter, save, scale=scale)
    return image_path, {"view": view}
# ======================================================================
# §10 CLUSTER HKS TIME-PROFILES — mean ± SEM per cluster
# ======================================================================
[docs]
def plot_cluster_profiles(
    H: np.ndarray,
    labels: np.ndarray,
    t_values: np.ndarray | None = None,
    *,
    log_t: bool = True,
    show_sem: bool = True,
    title: str = "Cluster HKS Profiles",
    xlabel: str = "Diffusion time t",
    ylabel: str = "HKS(x, t)",
    figsize: tuple[float, float] = (7, 4),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Mean ± SEM HKS profiles per cluster.

    Parameters
    ----------
    H : (N, T) array
    labels : (N,) int array
        Rows with a negative label (noise) are excluded.
    t_values : (T,) array or None
        x-coordinates. None → ``0 .. T-1``.
    log_t : bool
        Use log-scale on the x-axis. Non-positive t values cannot be shown on
        a log axis and are masked; a warning is logged when that happens.
    show_sem : bool
        Show shaded standard-error bands (sample SD / sqrt(n)) for clusters
        with more than one member.
    title, xlabel, ylabel, figsize, save

    Returns
    -------
    (Figure, Axes)

    Raises
    ------
    ValueError
        If H, labels and t_values have inconsistent shapes.
    """
    _apply_style()
    H = np.asarray(H, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64).ravel()
    if H.ndim != 2 or H.shape[0] != labels.size:
        raise ValueError(f"H must be (N, T) with N == len(labels); got {H.shape} and {labels.size}.")
    T = H.shape[1]
    if t_values is None:
        t_values = np.arange(T, dtype=np.float64)
    t_values = np.asarray(t_values, dtype=np.float64).ravel()
    if t_values.size != T:
        raise ValueError(f"t_values has {t_values.size} entries but H has {T} columns.")
    unique = np.unique(labels[labels >= 0])

    fig, ax = plt.subplots(figsize=figsize)
    for i, lab in enumerate(unique):
        cluster_h = H[labels == lab]
        n_members = cluster_h.shape[0]
        mean = cluster_h.mean(axis=0)
        color = CLUSTER_COLORS[i % len(CLUSTER_COLORS)]
        ax.plot(t_values, mean, color=color, label=f"Cluster {lab}", linewidth=1.8)
        if show_sem and n_members > 1:
            # SEM uses the sample standard deviation (ddof=1)
            sem = cluster_h.std(axis=0, ddof=1) / np.sqrt(n_members)
            ax.fill_between(t_values, mean - sem, mean + sem, color=color, alpha=0.2)
    if log_t:
        if np.any(t_values <= 0):
            logger.warning(
                "log_t=True: %d non-positive t value(s) are masked on the log axis; "
                "pass positive t_values to show every point.",
                int(np.sum(t_values <= 0)),
            )
        ax.set_xscale("log", nonpositive="mask")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if unique.size:
        ax.legend(fontsize=7, ncol=min(int(unique.size), 4))
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §11 SILHOUETTE DIAGRAM
# ======================================================================
[docs]
def plot_silhouette_diagram(
    H: np.ndarray,
    labels: np.ndarray,
    *,
    metric: str = "euclidean",
    title: str = "Silhouette Diagram",
    figsize: tuple[float, float] = (6, 5),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Per-sample silhouette plot ordered by cluster.

    Noise samples (label < 0) are excluded before scoring.

    Parameters
    ----------
    H : (N, T) feature array, or (N, N) distance matrix with
        ``metric="precomputed"``
    labels : (N,) int array
    metric : str
        Any metric accepted by ``sklearn.metrics.silhouette_samples``.
    title, figsize, save

    Returns
    -------
    (Figure, Axes)
    """
    from sklearn.metrics import silhouette_samples

    _apply_style()
    X = np.asarray(H)
    labels = np.asarray(labels, dtype=np.int64).ravel()
    valid = labels >= 0
    if metric == "precomputed":
        # a distance matrix must drop both the rows AND the columns of noise points
        X_v = X[np.ix_(valid, valid)]
    else:
        X_v = X[valid]
    lab_v = labels[valid]
    sil_vals = silhouette_samples(X_v, lab_v, metric=metric)
    # silhouette_score is the mean of silhouette_samples; reuse it instead of
    # recomputing every pairwise distance
    avg_sil = float(np.mean(sil_vals))
    unique = np.unique(lab_v)

    fig, ax = plt.subplots(figsize=figsize)
    y_lower = 0
    for i, lab in enumerate(unique):
        cluster_sil = np.sort(sil_vals[lab_v == lab])
        cluster_size = len(cluster_sil)
        y_upper = y_lower + cluster_size
        color = CLUSTER_COLORS[i % len(CLUSTER_COLORS)]
        ax.barh(
            range(y_lower, y_upper),
            cluster_sil,
            height=1.0,
            color=color,
            edgecolor="none",
            label=f"Cluster {lab}",
        )
        # cluster id next to the middle of its band
        ax.text(-0.05, y_lower + 0.5 * cluster_size, str(lab), fontsize=8, va="center", ha="right")
        y_lower = y_upper + 2  # gap between clusters
    ax.axvline(avg_sil, color="red", linestyle="--", linewidth=1, label=f"Mean = {avg_sil:.3f}")
    ax.set_xlabel("Silhouette coefficient")
    ax.set_ylabel("Vertices (sorted by cluster)")
    ax.set_title(title)
    ax.set_yticks([])
    ax.legend(fontsize=7, loc="lower right")
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §12 CLUSTER QUALITY COMPARISON — bar chart across methods
# ======================================================================
[docs]
def plot_quality_comparison(
    quality_dict: dict[str, dict[str, float]],
    *,
    metrics: list[str] | None = None,
    title: str = "Clustering Quality Comparison",
    figsize: tuple[float, float] = (8, 4),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Grouped bar chart comparing quality metrics across methods.

    Parameters
    ----------
    quality_dict : dict[str, dict[str, float]]
        Outer key = method name, inner dict = metric → value.
        Example: ``{"hdbscan": {"silhouette": 0.42}, ...}``
    metrics : list of str or None
        Which metrics to plot. None → the sorted union of all metric names.
        A metric missing for some method is drawn as an absent bar (NaN),
        not as a misleading 0.0.
    title, figsize, save

    Returns
    -------
    (Figure, Axes)

    Raises
    ------
    ValueError
        If *quality_dict* is empty.
    """
    _apply_style()
    methods = list(quality_dict)
    n_methods = len(methods)
    if n_methods == 0:
        raise ValueError("quality_dict is empty; nothing to compare.")
    if metrics is None:
        metrics = sorted({name for scores in quality_dict.values() for name in scores})
    n_metrics = len(metrics)
    x = np.arange(n_metrics)
    width = 0.8 / n_methods
    fig, ax = plt.subplots(figsize=figsize)
    for i, method in enumerate(methods):
        vals = [quality_dict[method].get(m, np.nan) for m in metrics]
        color = CLUSTER_COLORS[i % len(CLUSTER_COLORS)]
        ax.bar(x + i * width, vals, width, label=method, color=color)
    # centre each tick under its group of bars
    ax.set_xticks(x + width * (n_methods - 1) / 2)
    ax.set_xticklabels(metrics, rotation=30, ha="right", fontsize=8)
    ax.set_ylabel("Score")
    ax.set_title(title)
    ax.legend(fontsize=7)
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §13 METHOD AGREEMENT HEATMAP — ARI / NMI matrix
# ======================================================================
[docs]
def plot_agreement_heatmap(
    agreement_matrix: np.ndarray,
    method_names: list[str],
    *,
    metric_name: str = "ARI",
    cmap: str = "YlGnBu",
    title: str = "Inter-Method Agreement",
    figsize: tuple[float, float] = (6, 5),
    save: PathLike | None = None,
    vmin: float = 0.0,
    vmax: float = 1.0,
) -> tuple[Figure, Axes]:
    """Heatmap of pairwise clustering agreement.

    Parameters
    ----------
    agreement_matrix : (M, M) array_like
        Pairwise ARI or NMI scores.
    method_names : list of str
        One name per row/column.
    metric_name : str
        Colour-bar label.
    cmap, title, figsize, save
    vmin, vmax : float
        Colour limits (default 0–1). ARI can be negative, so pass e.g.
        ``vmin=-1`` to avoid clipping. Cell text switches to white above the
        midpoint of this range.

    Returns
    -------
    (Figure, Axes)

    Raises
    ------
    ValueError
        If the matrix is not square with one row per method name.
    """
    _apply_style()
    A = np.asarray(agreement_matrix, dtype=np.float64)
    M = len(method_names)
    if A.shape != (M, M):
        raise ValueError(f"agreement_matrix must be ({M}, {M}) to match method_names, got {A.shape}.")
    threshold = 0.5 * (vmin + vmax)
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(A, cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_xticks(range(M))
    ax.set_yticks(range(M))
    ax.set_xticklabels(method_names, rotation=45, ha="right", fontsize=8)
    ax.set_yticklabels(method_names, fontsize=8)
    # annotate cells; white text on the darker upper half of the colour range
    for i in range(M):
        for j in range(M):
            ax.text(
                j,
                i,
                f"{A[i, j]:.2f}",
                ha="center",
                va="center",
                fontsize=8,
                color="white" if A[i, j] > threshold else "black",
            )
    cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label(metric_name, fontsize=9)
    ax.set_title(title)
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §14 PERSISTENCE DIAGRAM — birth vs death scatter
# ======================================================================
[docs]
def plot_persistence_diagram(
    diagram: np.ndarray,
    *,
    title: str = "Persistence Diagram",
    figsize: tuple[float, float] = (5, 5),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Birth-death scatter for persistence-based clustering.

    Finite pairs are coloured by persistence (death − birth). Pairs with an
    infinite death, i.e. classes that never die, are drawn as black
    triangles on a dotted line just below the upper axis limit. Axis limits
    use the finite coordinates, padded by 10% of their range.

    Parameters
    ----------
    diagram : (n_pairs, 2) array
        Each row is (birth, death). May be empty.
    title, figsize, save

    Returns
    -------
    (Figure, Axes)

    Raises
    ------
    ValueError
        If *diagram* does not have shape (n_pairs, 2).
    """
    _apply_style()
    diagram = np.asarray(diagram, dtype=np.float64)
    if diagram.size == 0:
        diagram = diagram.reshape(0, 2)
    if diagram.ndim != 2 or diagram.shape[1] != 2:
        raise ValueError(f"diagram must have shape (n_pairs, 2), got {diagram.shape}")
    births = diagram[:, 0]
    deaths = diagram[:, 1]
    finite = np.isfinite(births) & np.isfinite(deaths)
    essential = np.isfinite(births) & np.isinf(deaths)

    # limits from finite coordinates only — an infinite death would make them inf
    coords = np.concatenate([births[np.isfinite(births)], deaths[np.isfinite(deaths)]])
    lo, hi = (float(coords.min()), float(coords.max())) if coords.size else (0.0, 1.0)
    span = hi - lo
    pad = 0.1 * span if span > 0 else 0.1
    lims = [lo - pad, hi + pad]

    fig, ax = plt.subplots(figsize=figsize)
    # diagonal: points near it have low persistence
    ax.plot(lims, lims, "k--", linewidth=0.5, alpha=0.5)
    if finite.any():
        sc = ax.scatter(
            births[finite],
            deaths[finite],
            c=deaths[finite] - births[finite],
            cmap="plasma",
            s=30,
            alpha=0.7,
            edgecolors="k",
            linewidths=0.3,
        )
        fig.colorbar(sc, ax=ax, label="Persistence")
    if essential.any():
        y_inf = hi + 0.5 * pad
        ax.axhline(y_inf, color="k", linewidth=0.5, linestyle=":", alpha=0.5)
        ax.scatter(
            births[essential],
            np.full(int(essential.sum()), y_inf),
            marker="^",
            color="k",
            s=30,
            label="death = ∞",
        )
        ax.legend(fontsize=7, loc="lower right")
    ax.set_xlabel("Birth")
    ax.set_ylabel("Death")
    ax.set_title(title)
    ax.set_xlim(lims)
    ax.set_ylim(lims)
    ax.set_aspect("equal")
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §15 GNMF TEMPORAL FACTORS — F matrix profiles
# ======================================================================
[docs]
def plot_gnmf_temporal_factors(
    F: np.ndarray,
    t_values: np.ndarray | None = None,
    *,
    log_t: bool = True,
    title: str = "GNMF Temporal Factors",
    figsize: tuple[float, float] = (7, 4),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Plot rows of the GNMF F matrix as temporal profiles.

    Each row of F is drawn as one line (one component), coloured with
    ``CLUSTER_COLORS[k % len(CLUSTER_COLORS)]``.

    Parameters
    ----------
    F : (K, T) array
        A 1-D array of length T is treated as a single component.
    t_values : (T,) array or None
        x-axis values. None → integer indices ``0 .. T-1``.
    log_t : bool
        Log-scale the x-axis. If any t is <= 0 (as with the default
        indices, which start at 0) a ``symlog`` axis is used instead so
        those points are not silently dropped.
    title, figsize, save
        Figure title, size in inches, and optional export path.

    Returns
    -------
    (Figure, Axes)

    Raises
    ------
    ValueError
        If ``F`` is not 1-D/2-D or ``t_values`` does not have length T.
    """
    _apply_style()
    F = np.atleast_2d(np.asarray(F, dtype=np.float64))
    if F.ndim != 2:
        raise ValueError(f"F must be a (K, T) matrix; got shape {F.shape}")
    K, T = F.shape
    if t_values is None:
        t_values = np.arange(T, dtype=np.float64)
    t_values = np.asarray(t_values, dtype=np.float64)
    if t_values.shape != (T,):
        raise ValueError(f"t_values must have shape ({T},); got {t_values.shape}")

    fig, ax = plt.subplots(figsize=figsize)
    for k in range(K):
        color = CLUSTER_COLORS[k % len(CLUSTER_COLORS)]
        ax.plot(t_values, F[k], color=color, linewidth=1.5, label=f"Component {k}")
    if log_t:
        positive = t_values[t_values > 0]
        if positive.size == T:
            ax.set_xscale("log")
        else:
            # a pure log axis cannot place t <= 0; symlog keeps a linear core
            linthresh = float(positive.min()) if positive.size else 1.0
            ax.set_xscale("symlog", linthresh=linthresh)
    ax.set_xlabel("Diffusion time t")
    ax.set_ylabel("F(t)")
    ax.set_title(title)
    if K:
        ax.legend(fontsize=7, ncol=min(K, 4))
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §16 BAYESIAN CONFIRMATION — posteriors + credible intervals
# ======================================================================
[docs]
def plot_bayesian_confirmation(
    label_probabilities: np.ndarray,
    credible_intervals: dict[int, dict[str, Any]],
    agreement_ari: float,
    *,
    title: str = "Bayesian Cluster Confirmation",
    figsize: tuple[float, float] = (10, 4),
    save: PathLike | None = None,
) -> tuple[Figure, tuple[Axes, Axes]]:
    """Two-panel Bayesian confirmation diagnostic.

    Left: stacked posterior membership probabilities per vertex. Vertices
    are sorted by their most probable cluster and then by that maximum
    probability.
    Right: for each cluster in ``credible_intervals``, a bar joining the
    Euclidean norms of the ``"hdi_3"`` and ``"hdi_97"`` vectors, with the
    norm of ``"mean"`` as a dot. These norms summarise the 94% HDI bounds
    of the centroid; they are not themselves an HDI of ‖μ_k‖.

    Cluster ``k`` is coloured ``CLUSTER_COLORS[k % len(CLUSTER_COLORS)]``
    in both panels, so colours agree even when some cluster ids are
    missing from ``credible_intervals``.

    Parameters
    ----------
    label_probabilities : (N, K) array
        Column k holds P(cluster k) for each vertex.
    credible_intervals : dict[int, dict]
        Per-cluster summaries with ``"mean"``, ``"hdi_3"`` and ``"hdi_97"``
        entries (from confirm_clusters_bayesian).
    agreement_ari : float
        Shown in the left-panel title.
    title, figsize, save
        Figure suptitle, size in inches, and optional export path.

    Returns
    -------
    (Figure, (Axes, Axes))
    """
    _apply_style()
    P = np.asarray(label_probabilities, dtype=np.float64)
    N, K = P.shape
    n_colors = len(CLUSTER_COLORS)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=figsize)

    # --- Left: sorted probability distribution ---
    # primary key = most probable cluster, secondary = its probability
    order = np.lexsort((P.max(axis=1), P.argmax(axis=1)))
    P_sorted = P[order]
    x = np.arange(N)
    lower = np.zeros(N)
    for k in range(K):
        upper = lower + P_sorted[:, k]
        ax1.fill_between(
            x, lower, upper, color=CLUSTER_COLORS[k % n_colors], alpha=0.8, label=f"Cl {k}"
        )
        lower = upper  # rebind instead of mutating the array handed to fill_between
    ax1.set_xlabel("Vertices (sorted)")
    ax1.set_ylabel("P(cluster)")
    ax1.set_title(f"Posterior Membership (ARI={agreement_ari:.3f})")
    ax1.set_xlim(0, N)
    ax1.set_ylim(0, 1)
    if K:
        ax1.legend(fontsize=6, ncol=min(K, 4), loc="lower left")

    # --- Right: credible intervals ---
    clusters = sorted(credible_intervals.keys())
    y_pos = np.arange(len(clusters))
    for y, k in zip(y_pos, clusters):
        ci = credible_intervals[k]
        # norms of the mean / bound vectors as summary scalars
        mean_norm = np.linalg.norm(ci["mean"])
        lo_norm = np.linalg.norm(ci["hdi_3"])
        hi_norm = np.linalg.norm(ci["hdi_97"])
        color = CLUSTER_COLORS[int(k) % n_colors]
        ax2.plot(
            [lo_norm, hi_norm],
            [y, y],
            color=color,
            linewidth=2.5,
            solid_capstyle="round",
        )
        ax2.plot(mean_norm, y, "o", color=color, markersize=6)
    ax2.set_yticks(y_pos)
    ax2.set_yticklabels([f"Cluster {k}" for k in clusters], fontsize=8)
    ax2.set_xlabel("‖μ_k‖ (centroid norm)")
    ax2.set_title("94% Credible Intervals")
    ax2.invert_yaxis()
    fig.suptitle(title, fontsize=11, y=1.02)
    fig.tight_layout()
    _savefig(fig, save)
    return fig, (ax1, ax2)
# ======================================================================
# §17 CLUSTER SIZE DISTRIBUTION
# ======================================================================
[docs]
def plot_cluster_sizes(
    labels: np.ndarray,
    *,
    title: str = "Cluster Size Distribution",
    figsize: tuple[float, float] = (6, 3.5),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Bar chart of cluster sizes with noise count.

    Draws one bar per distinct label in ascending label order, each
    annotated with its count. Negative labels are drawn in light grey and
    tick-labelled "Noise". A cluster bar's colour comes from its rank among
    the non-negative labels. This is the same rank-based assignment the
    scatter plots in this module use, so a noise bar does not shift the
    cluster colours.

    Parameters
    ----------
    labels : (N,) int array
        Negative values mark noise.
    title, figsize, save
        Figure title, size in inches, and optional export path.

    Returns
    -------
    (Figure, Axes)
    """
    _apply_style()
    labels = np.asarray(labels, dtype=np.int64)
    unique, counts = np.unique(labels, return_counts=True)
    fig, ax = plt.subplots(figsize=figsize)
    # small gap between each bar top and its count annotation
    text_offset = counts.max() * 0.01 if counts.size else 0.0
    cluster_rank = 0
    for x, (lab, cnt) in enumerate(zip(unique, counts)):
        if lab < 0:
            color = "lightgray"
        else:
            color = CLUSTER_COLORS[cluster_rank % len(CLUSTER_COLORS)]
            cluster_rank += 1
        ax.bar(x, cnt, color=color, edgecolor="k", linewidth=0.3)
        ax.text(x, cnt + text_offset, str(cnt), ha="center", va="bottom", fontsize=7)
    ax.set_xticks(range(len(unique)))
    ax.set_xticklabels(
        ["Noise" if lab < 0 else str(lab) for lab in unique],
        fontsize=8,
    )
    ax.set_ylabel("Number of vertices")
    ax.set_title(title)
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §18 UMAP / PCA SCATTER — embedding coloured by clusters
# ======================================================================
[docs]
def plot_cluster_scatter(
    embedding: np.ndarray,
    labels: np.ndarray,
    *,
    method_name: str = "UMAP",
    noise_color: str = "lightgray",
    noise_alpha: float = 0.3,
    point_size: float = 3.0,
    title: str | None = None,
    figsize: tuple[float, float] = (6, 5),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """2D scatter of dimensionality-reduced embedding coloured by cluster.

    Noise points (negative labels) are drawn first so clusters sit on top.
    Cluster colours follow the sorted order of the non-negative labels
    (``CLUSTER_COLORS[rank % len(CLUSTER_COLORS)]``).

    Parameters
    ----------
    embedding : (N, 2) array
        Only the first two columns are plotted.
    labels : (N,) int array
        Negative values mark noise.
    method_name : str
        For axis labels and the default title (e.g., "UMAP", "PCA", "t-SNE").
    noise_color, noise_alpha, point_size
        Noise colour and opacity, and marker size for all points.
    title, figsize, save
        Title (None → "Cluster Map in {method_name} Space"), figure size,
        and optional export path.

    Returns
    -------
    (Figure, Axes)
    """
    _apply_style()
    embedding = np.asarray(embedding, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    fig, ax = plt.subplots(figsize=figsize)
    n_entries = 0

    # noise first (background). ``color=`` rather than ``c=`` so a single
    # colour is never mistaken for per-point values (e.g. an RGB tuple when
    # a group happens to contain exactly 3 or 4 points).
    noise = labels < 0
    if noise.any():
        ax.scatter(
            embedding[noise, 0],
            embedding[noise, 1],
            color=noise_color,
            alpha=noise_alpha,
            s=point_size,
            label="Noise",
            rasterized=True,
        )
        n_entries += 1

    cluster_ids = np.unique(labels[~noise])
    for rank, lab in enumerate(cluster_ids):
        mask = labels == lab
        ax.scatter(
            embedding[mask, 0],
            embedding[mask, 1],
            color=CLUSTER_COLORS[rank % len(CLUSTER_COLORS)],
            s=point_size,
            alpha=0.7,
            label=f"Cluster {lab}",
            rasterized=True,
        )
    n_entries += len(cluster_ids)

    ax.set_xlabel(f"{method_name} 1")
    ax.set_ylabel(f"{method_name} 2")
    ax.set_title(title or f"Cluster Map in {method_name} Space")
    if n_entries:
        ax.legend(fontsize=7, markerscale=3, ncol=min(n_entries, 4))
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §19 CO-CLUSTERING CHECKERBOARD — vertex × time block structure
# ======================================================================
[docs]
def plot_coclustering_heatmap(
    H: np.ndarray,
    row_labels: np.ndarray,
    col_labels: np.ndarray,
    *,
    cmap: str = "viridis",
    title: str = "Co-Clustering Structure",
    figsize: tuple[float, float] = (8, 6),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Reordered heatmap revealing vertex × time co-cluster blocks.

    Sorts rows and columns by their cluster labels so that the
    checkerboard block structure of the co-clustering is visible. The sort
    is stable, so items keep their original relative order within a
    cluster. ``log(H + 1e-12)`` is displayed, and thin white lines mark
    label changes.

    Parameters
    ----------
    H : (N, T) array
    row_labels : (N,) int array
        Vertex cluster labels.
    col_labels : (T,) int array
        Time/scale cluster labels.
    cmap, title, figsize, save
        Colourmap, title, figure size, and optional export path.

    Returns
    -------
    (Figure, Axes)

    Raises
    ------
    ValueError
        If H is not 2-D or the label lengths do not match its shape.
    """
    _apply_style()
    H = np.asarray(H, dtype=np.float64)
    row_labels = np.asarray(row_labels, dtype=np.int64).ravel()
    col_labels = np.asarray(col_labels, dtype=np.int64).ravel()
    if H.ndim != 2:
        raise ValueError(f"H must be 2-D (N, T); got shape {H.shape}")
    if row_labels.size != H.shape[0] or col_labels.size != H.shape[1]:
        # a shorter label vector would otherwise silently drop rows/columns
        raise ValueError(
            f"label lengths ({row_labels.size}, {col_labels.size}) "
            f"do not match H shape {H.shape}"
        )

    # sort rows and columns by cluster label (stable → deterministic)
    row_order = np.argsort(row_labels, kind="stable")
    col_order = np.argsort(col_labels, kind="stable")
    H_sorted = H[np.ix_(row_order, col_order)]
    # apply log for visual dynamic range
    H_vis = np.log(H_sorted + 1e-12)
    fig, ax = plt.subplots(figsize=figsize)
    im = ax.imshow(H_vis, aspect="auto", cmap=cmap, interpolation="none")

    # block boundaries sit where the sorted label changes
    row_breaks = np.flatnonzero(np.diff(row_labels[row_order])) + 1
    col_breaks = np.flatnonzero(np.diff(col_labels[col_order])) + 1
    for i in row_breaks:
        ax.axhline(i - 0.5, color="white", linewidth=0.5)
    for j in col_breaks:
        ax.axvline(j - 0.5, color="white", linewidth=0.5)

    fig.colorbar(im, ax=ax, label="log(descriptor value)", shrink=0.8)
    ax.set_xlabel("Time/Energy scale (sorted by cluster)")
    ax.set_ylabel("Vertices (sorted by cluster)")
    ax.set_title(title)
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §20 SUMMARY PANEL — comprehensive overview figure
# ======================================================================
[docs]
def plot_cluster_summary(
    vertices: np.ndarray,
    faces: np.ndarray,
    H: np.ndarray,
    labels: np.ndarray,
    t_values: np.ndarray | None = None,
    *,
    method_name: str = "GNMF",
    figsize: tuple[float, float] = (14, 10),
    save: PathLike | None = None,
) -> tuple[Figure, np.ndarray]:
    """Comprehensive 2×3 summary panel for a clustering result.

    Layout:
        [0,0] 3D cluster map: ``plot_cluster_map`` is rendered to a temporary
              PNG and embedded. If rendering fails, the error text is shown.
        [0,1] Cluster sizes bar chart (noise excluded)
        [0,2] 2-D embedding of log(H) coloured by cluster. UMAP is used,
              or PCA if ``umap`` cannot be imported.
        [1,0] Mean ± SEM descriptor profile per cluster
        [1,1] Silhouette diagram on H. A message is shown if it cannot be
              computed.
        [1,2] Text box with the method name and number of clusters

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    H : (V, T) array
        Per-vertex descriptor matrix.
    labels : (V,) int array
        Cluster labels; negative values mark noise.
    t_values : (T,) or None
        Scale values for the profile panel. None → integer indices. If any
        t is <= 0, a symlog axis replaces the log axis so those points stay
        visible.
    method_name : str
    figsize, save

    Returns
    -------
    (Figure, ndarray of Axes)
    """
    _apply_style()
    import matplotlib.image as mpimg

    H = np.asarray(H, dtype=np.float64)
    labels = np.asarray(labels, dtype=np.int64)
    cluster_ids, cluster_counts = np.unique(labels[labels >= 0], return_counts=True)
    n_clusters = len(cluster_ids)
    # one colour per cluster, by rank among the non-negative labels
    colors = [CLUSTER_COLORS[i % len(CLUSTER_COLORS)] for i in range(n_clusters)]
    fig, axes = plt.subplots(2, 3, figsize=figsize)

    # --- [0,0] 3D render as embedded image ---
    ax = axes[0, 0]
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tf:
        tmp_path = tf.name
    try:
        plot_cluster_map(vertices, faces, labels, views=["left_lateral"], save=tmp_path)
        ax.imshow(mpimg.imread(tmp_path))
        ax.set_title(f"{method_name} Cluster Map")
    except Exception as e:  # best-effort: keep the rest of the summary usable
        ax.text(
            0.5,
            0.5,
            f"3D render failed:\n{e}",
            ha="center",
            va="center",
            fontsize=8,
            transform=ax.transAxes,
        )
    finally:
        ax.axis("off")
        try:
            os.unlink(tmp_path)
        except OSError:
            pass

    # --- [0,1] Cluster sizes ---
    ax = axes[0, 1]
    ax.bar(range(n_clusters), cluster_counts, color=colors, edgecolor="k", linewidth=0.3)
    ax.set_xticks(range(n_clusters))
    ax.set_xticklabels([str(c) for c in cluster_ids], fontsize=7)
    ax.set_ylabel("n vertices")
    ax.set_title("Cluster Sizes")

    # --- [0,2] UMAP / PCA scatter ---
    _summary_embedding_panel(axes[0, 2], H, labels, cluster_ids, colors)

    # --- [1,0] HKS profiles per cluster ---
    ax = axes[1, 0]
    if t_values is None:
        t_vals = np.arange(H.shape[1], dtype=np.float64)
    else:
        t_vals = np.asarray(t_values, dtype=np.float64)
    for lab, color in zip(cluster_ids, colors):
        members = H[labels == lab]
        mean_h = members.mean(axis=0)
        ax.plot(t_vals, mean_h, color=color, linewidth=1.2, label=f"Cl {lab}")
        if len(members) > 1:
            sem = members.std(axis=0) / np.sqrt(len(members))
            ax.fill_between(t_vals, mean_h - sem, mean_h + sem, color=color, alpha=0.15)
    positive_t = t_vals[t_vals > 0]
    if positive_t.size == t_vals.size:
        ax.set_xscale("log")
    else:
        # the default index axis starts at 0, which a pure log axis drops
        ax.set_xscale("symlog", linthresh=float(positive_t.min()) if positive_t.size else 1.0)
    ax.set_xlabel("t")
    ax.set_ylabel("HKS")
    ax.set_title("Cluster HKS Profiles")
    if n_clusters:
        ax.legend(fontsize=6, ncol=min(n_clusters, 4))

    # --- [1,1] Silhouette ---
    ax = axes[1, 1]
    try:
        from sklearn.metrics import silhouette_samples

        valid = labels >= 0
        valid_labels = labels[valid]
        sil = silhouette_samples(H[valid], valid_labels)
        y_lower = 0
        for lab, color in zip(cluster_ids, colors):
            cl_sil = np.sort(sil[valid_labels == lab])
            y_upper = y_lower + len(cl_sil)
            ax.barh(range(y_lower, y_upper), cl_sil, height=1.0, color=color, edgecolor="none")
            y_lower = y_upper + 2  # visual gap between clusters
        ax.set_yticks([])
        ax.set_xlabel("Silhouette")
        ax.set_title("Silhouette Diagram")
    except Exception:  # best-effort, e.g. fewer than two clusters or no sklearn
        ax.text(
            0.5,
            0.5,
            "Could not compute silhouette",
            ha="center",
            va="center",
            fontsize=8,
            transform=ax.transAxes,
        )

    # --- [1,2] Method summary text ---
    ax = axes[1, 2]
    ax.text(
        0.5,
        0.5,
        f"Method: {method_name}\nk = {n_clusters}",
        ha="center",
        va="center",
        fontsize=12,
        transform=ax.transAxes,
        bbox=dict(boxstyle="round", facecolor="wheat", alpha=0.5),
    )
    ax.axis("off")
    fig.suptitle(f"Clustering Summary — {method_name}", fontsize=13, y=1.01)
    fig.tight_layout()
    _savefig(fig, save)
    return fig, axes


def _summary_embedding_panel(
    ax: Axes,
    H: np.ndarray,
    labels: np.ndarray,
    cluster_ids: np.ndarray,
    colors: list,
) -> None:
    """Draw a 2-D UMAP (PCA if umap is unavailable) embedding of log(H) on *ax*."""
    X = np.log(H + 1e-12)
    try:
        import umap as umap_mod

        embedding = umap_mod.UMAP(n_components=2, random_state=42).fit_transform(X)
        panel_title, x_label, y_label = "UMAP Embedding", "UMAP 1", "UMAP 2"
    except ImportError:
        from sklearn.decomposition import PCA

        embedding = PCA(n_components=2).fit_transform(X)
        panel_title, x_label, y_label = "PCA Embedding", "PC 1", "PC 2"

    # noise first (background), for both the UMAP and the PCA path;
    # ``color=`` avoids a colour being read as per-point values
    noise = labels < 0
    if noise.any():
        ax.scatter(
            embedding[noise, 0],
            embedding[noise, 1],
            color="lightgray",
            s=1,
            alpha=0.2,
            rasterized=True,
        )
    for lab, color in zip(cluster_ids, colors):
        mask = labels == lab
        ax.scatter(
            embedding[mask, 0], embedding[mask, 1], color=color, s=2, alpha=0.5, rasterized=True
        )
    ax.set_title(panel_title)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
# ======================================================================
# §21 SPATIO-TEMPORAL FIELD — small multiples on unfolded coordinates
# ======================================================================
[docs]
def plot_spatiotemporal_field(
    unfolded_coords: np.ndarray,
    faces: np.ndarray,
    H: np.ndarray,
    t_values: np.ndarray | None = None,
    *,
    n_panels: int = 8,
    t_indices: list[int] | None = None,
    cmap: str = "magma",
    log_norm: bool = True,
    vmin: float | None = None,
    vmax: float | None = None,
    subfield_labels: np.ndarray | None = None,
    boundary_color: str = "k",
    boundary_width: float = 0.6,
    descriptor_name: str = "HKS",
    xlabel: str = "AP coordinate",
    ylabel: str = "PD coordinate",
    n_cols: int = 4,
    figsize: tuple[float, float] | None = None,
    save: PathLike | None = None,
) -> tuple[Figure, np.ndarray]:
    """Small-multiples grid of a descriptor field on unfolded surface.

    Renders per-vertex spectral descriptor values (HKS, WKS) on the
    2D unfolded coordinate system (e.g., HippUnfold's AP × PD sheet)
    at a geometric selection of scales. The panels show the progression
    from the finest to the coarsest scale.
    Each panel is rendered with ``tripcolor`` directly on the mesh
    triangulation — no interpolation to a regular grid is needed,
    preserving vertex-exact values.

    Parameters
    ----------
    unfolded_coords : (V, 2) array
        Per-vertex 2D coordinates in unfolded space.
        Column 0 = AP (or u), column 1 = PD (or v).
    faces : (F, 3) array
        Triangle indices (same mesh topology as the folded surface).
    H : (V, T) array
        Per-vertex descriptor matrix across T scales.
    t_values : (T,) array or None
        Scale parameter values for panel titles. None → integer indices.
    n_panels : int
        Number of panels requested (ignored when ``t_indices`` is given).
        Indices are geometrically spaced and always include the first
        (0) and last (T-1) column. Coinciding indices are merged, so for
        small T fewer panels may be shown.
    t_indices : list of int or None
        Specific column indices into H. None → geometrically spaced.
    cmap : str
        Colourmap for the descriptor field.
    log_norm : bool
        Use ``LogNorm`` for colour scaling (recommended for HKS).
    vmin, vmax : float or None
        Colour range shared by all panels. None → 2nd/98th percentiles of
        the displayed columns of H.
    subfield_labels : (V,) int array or None
        If provided, overlay subfield boundaries via ``tricontour``. Lines
        are drawn halfway between consecutive distinct label values. With
        fewer than two distinct labels, nothing is drawn.
    boundary_color : str
        Colour for subfield boundary lines.
    boundary_width : float
        Width of boundary lines.
    descriptor_name : str
        Label for colourbar (e.g. ``"HKS"``, ``"WKS"``).
    xlabel, ylabel : str
        Axis labels. They appear only on the bottom visible panel of each
        column and on the first column, respectively.
    n_cols : int
        Number of columns in the panel grid.
    figsize : tuple or None
        Figure size. None → auto from n_panels.
    save : PathLike or None

    Returns
    -------
    (Figure, ndarray of Axes)
    """
    _apply_style()
    import matplotlib.tri as mtri
    from matplotlib.colors import LogNorm, Normalize

    H = np.asarray(H, dtype=np.float64)
    uv = np.asarray(unfolded_coords, dtype=np.float64)
    fcs = np.asarray(faces, dtype=np.int64)
    _V, T = H.shape

    # --- select scale indices ---
    if t_indices is not None:
        sel = list(t_indices)
    else:
        # geometric spacing over positions 1..T, shifted to 0-based indices,
        # so the finest (0) and coarsest (T-1) columns are always included;
        # rounding (not truncation) keeps collisions to a minimum
        positions = np.rint(np.geomspace(1, T, max(n_panels, 1))).astype(np.int64) - 1
        sel = np.unique(positions.clip(0, T - 1)).tolist()
    n_panels = len(sel)
    if t_values is None:
        t_values = np.arange(T, dtype=np.float64)
    t_values = np.asarray(t_values, dtype=np.float64)

    # --- build triangulation ---
    tri = mtri.Triangulation(uv[:, 0], uv[:, 1], fcs)

    # --- colour normalisation (shared across all panels) ---
    if vmin is None:
        vmin = float(np.nanpercentile(H[:, sel], 2))
    if vmax is None:
        vmax = float(np.nanpercentile(H[:, sel], 98))
    if log_norm:
        vmin_safe = max(vmin, 1e-12)
        norm = LogNorm(vmin=vmin_safe, vmax=max(vmax, vmin_safe * 2))
    else:
        norm = Normalize(vmin=vmin, vmax=vmax)

    # --- subfield boundary levels (independent of the panel) ---
    labels_f = None
    boundary_levels = None
    if subfield_labels is not None:
        labels_f = np.asarray(subfield_labels, dtype=np.float64)
        present = np.unique(labels_f[np.isfinite(labels_f)])
        if present.size >= 2:
            # one contour halfway between consecutive label values; works for
            # negative and non-contiguous label sets
            boundary_levels = 0.5 * (present[:-1] + present[1:])

    # --- layout ---
    n_rows = (n_panels + n_cols - 1) // n_cols
    if figsize is None:
        # 2:1 aspect ratio per panel (AP is typically ~2× PD)
        figsize = (4.0 * n_cols, 2.5 * n_rows)
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=figsize,
        sharex=True,
        sharey=True,
        constrained_layout=True,
    )
    axes_flat = np.atleast_1d(axes).ravel()

    # --- render each panel ---
    tpc = None
    for pi, t_idx in enumerate(sel):
        ax = axes_flat[pi]
        tpc = ax.tripcolor(
            tri,
            H[:, t_idx],
            shading="gouraud",
            cmap=cmap,
            norm=norm,
            rasterized=True,  # keeps PDF/SVG file sizes small
        )
        if boundary_levels is not None:
            ax.tricontour(
                tri,
                labels_f,
                levels=boundary_levels,
                colors=boundary_color,
                linewidths=boundary_width,
            )
        ax.set_title(f"t = {t_values[t_idx]:.2g}", fontsize=8)
        ax.set_aspect("equal")
        ax.tick_params(labelsize=6)
        # axis labels on the outer edge of the visible grid only
        if pi + n_cols >= n_panels:  # no visible panel below this one
            ax.set_xlabel(xlabel, fontsize=7)
            # sharex hides inner tick labels; re-enable above hidden panels
            ax.xaxis.set_tick_params(labelbottom=True)
        if pi % n_cols == 0:
            ax.set_ylabel(ylabel, fontsize=7)

    # hide unused panels
    for pi in range(n_panels, len(axes_flat)):
        axes_flat[pi].set_visible(False)

    # shared colourbar
    if tpc is not None:
        cbar = fig.colorbar(
            tpc,
            ax=axes_flat[:n_panels].tolist(),
            location="right",
            shrink=0.8,
            pad=0.02,
        )
        cbar.set_label(
            f"{descriptor_name} ({'log scale' if log_norm else 'linear'})",
            fontsize=8,
        )
        cbar.ax.tick_params(labelsize=6)
    _savefig(fig, save)
    return fig, axes
# ======================================================================
# §22 SPATIO-TEMPORAL ANIMATION — GIF / MP4 across scales
# ======================================================================
[docs]
def plot_spatiotemporal_animation(
    unfolded_coords: np.ndarray,
    faces: np.ndarray,
    H: np.ndarray,
    t_values: np.ndarray | None = None,
    *,
    cmap: str = "magma",
    log_norm: bool = True,
    vmin: float | None = None,
    vmax: float | None = None,
    subfield_labels: np.ndarray | None = None,
    descriptor_name: str = "HKS",
    fps: int = 8,
    figsize: tuple[float, float] = (6, 3.5),
    save: PathLike | None = None,
    xlabel: str = "AP coordinate",
    ylabel: str = "PD coordinate",
) -> Any | None:
    """Animate descriptor field across scales on the unfolded surface.

    Produces one frame per column of H (in column order). If ``save`` is
    given, it also writes a GIF or video whose format is chosen by the
    file extension.
    Uses ``FuncAnimation`` with a single ``tripcolor`` artist and
    ``set_array`` updates for efficient rendering.

    Parameters
    ----------
    unfolded_coords : (V, 2) array
    faces : (F, 3) array
    H : (V, T) array
    t_values : (T,) array or None
        Shown in the frame title. None → integer indices.
    cmap : str
    log_norm : bool
    vmin, vmax : float or None
        Colour range shared by all frames. None → 2nd/98th percentiles of H.
    subfield_labels : (V,) int array or None
        Static boundaries drawn halfway between consecutive distinct label
        values. Nothing is drawn with fewer than two labels.
    descriptor_name : str
    fps : int
        Frames per second for the output animation (must be > 0).
    figsize : tuple
    save : PathLike or None
        Output path. Extension determines format:
        ``.gif`` → Pillow writer, ``.mp4``/``.avi``/``.mov`` → ffmpeg
        writer, anything else → matplotlib's default writer.
    xlabel, ylabel : str
        Axis labels (defaults match the previous hard-coded labels).

    Returns
    -------
    matplotlib.animation.FuncAnimation
        The animation object. It is returned even when ``save`` is given,
        and the figure is left open so the animation stays displayable.

    Raises
    ------
    ValueError
        If ``fps`` is not positive.
    """
    import matplotlib.animation as animation
    import matplotlib.tri as mtri
    from matplotlib.colors import LogNorm, Normalize

    if fps <= 0:
        raise ValueError(f"fps must be a positive integer; got {fps!r}")
    _apply_style()
    H = np.asarray(H, dtype=np.float64)
    uv = np.asarray(unfolded_coords, dtype=np.float64)
    fcs = np.asarray(faces, dtype=np.int64)
    _V, T = H.shape
    if t_values is None:
        t_values = np.arange(T, dtype=np.float64)
    t_values = np.asarray(t_values, dtype=np.float64)
    tri = mtri.Triangulation(uv[:, 0], uv[:, 1], fcs)

    # --- normalisation ---
    if vmin is None:
        vmin = float(np.nanpercentile(H, 2))
    if vmax is None:
        vmax = float(np.nanpercentile(H, 98))
    if log_norm:
        vmin_safe = max(vmin, 1e-12)
        norm = LogNorm(vmin=vmin_safe, vmax=max(vmax, vmin_safe * 2))
    else:
        norm = Normalize(vmin=vmin, vmax=vmax)

    # --- setup figure with first frame ---
    fig, ax = plt.subplots(figsize=figsize)
    tpc = ax.tripcolor(
        tri,
        H[:, 0],
        shading="gouraud",
        cmap=cmap,
        norm=norm,
        rasterized=True,
    )
    fig.colorbar(tpc, ax=ax, label=descriptor_name, shrink=0.8)

    # static subfield boundaries (drawn once), halfway between consecutive
    # distinct labels so negative/non-contiguous label sets work too
    if subfield_labels is not None:
        labels_f = np.asarray(subfield_labels, dtype=np.float64)
        present = np.unique(labels_f[np.isfinite(labels_f)])
        if present.size >= 2:
            ax.tricontour(
                tri,
                labels_f,
                levels=0.5 * (present[:-1] + present[1:]),
                colors="k",
                linewidths=0.5,
            )
    ax.set_aspect("equal")
    ax.set_xlabel(xlabel, fontsize=8)
    ax.set_ylabel(ylabel, fontsize=8)
    title_text = ax.set_title(
        f"{descriptor_name} t = {t_values[0]:.2g}",
        fontsize=9,
    )

    # --- animation update ---
    def _update(frame_idx):
        """Show column ``frame_idx`` of H and update the title."""
        # with gouraud shading the collection holds per-vertex values
        tpc.set_array(H[:, frame_idx])
        title_text.set_text(f"{descriptor_name} t = {t_values[frame_idx]:.2g}")
        return (tpc, title_text)

    ani = animation.FuncAnimation(
        fig,
        _update,
        frames=T,
        interval=max(1, 1000 // fps),  # ms per frame for live display
        blit=False,  # tripcolor + blit can produce blank frames
    )
    if save is not None:
        save = Path(save)
        save.parent.mkdir(parents=True, exist_ok=True)
        ext = save.suffix.lower()
        if ext == ".gif":
            ani.save(str(save), writer="pillow", fps=fps, dpi=200)
        elif ext in (".mp4", ".avi", ".mov"):
            ani.save(str(save), writer="ffmpeg", fps=fps, dpi=200)
        else:
            ani.save(str(save), fps=fps, dpi=200)
        logger.info("Saved animation → %s", save)
    return ani
# ======================================================================
# §23 HOVMÖLLER DIAGRAM — position × scale 2D heatmap
# ======================================================================
[docs]
def plot_hovmoller(
    unfolded_coords: np.ndarray,
    H: np.ndarray,
    t_values: np.ndarray | None = None,
    *,
    axis: Literal["AP", "PD"] = "AP",
    n_bins: int = 100,
    cmap: str = "viridis",
    log_norm: bool = True,
    log_t: bool = True,
    descriptor_name: str = "HKS",
    title: str | None = None,
    figsize: tuple[float, float] = (8, 4),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Hovmöller diagram: averaged descriptor along one spatial axis × scale.

    Collapses the orthogonal spatial axis by averaging, producing a
    2D heatmap of (position along AP or PD) × (diffusion scale t).
    Reveals how the multi-scale spectral signature varies along the
    hippocampal long axis (AP) or proximal-distal axis (PD).

    Parameters
    ----------
    unfolded_coords : (V, 2) array
        Column 0 = AP, column 1 = PD.
    H : (V, T) array
        Per-vertex descriptor matrix.
    t_values : (T,) array or None
        None → integer indices.
    axis : str
        Which spatial axis to retain: ``"AP"`` (column 0); any other value
        selects column 1 (PD). The other axis is averaged out.
    n_bins : int
        Number of equal-width bins along the retained spatial axis. Bins
        that contain no vertex are left empty (NaN, drawn as missing).
    cmap : str
    log_norm : bool
        Log colour scale from the 2nd percentile of positive values to the
        98th percentile of all binned values. If no binned value is
        positive, a linear scale is used instead.
    log_t : bool
        Log-scale the t-axis. If any t is <= 0 (e.g. the default indices),
        a symlog axis is used so those columns stay visible.
    descriptor_name : str
    title : str or None
    figsize, save

    Returns
    -------
    (Figure, Axes)
    """
    _apply_style()
    from matplotlib.colors import LogNorm, Normalize

    def _cell_edges(centers: np.ndarray, geometric: bool) -> np.ndarray:
        """Return len(centers)+1 boundaries at the (log-)midpoints of *centers*."""
        x = np.log(centers) if geometric else centers
        if x.size == 1:
            edges = np.array([x[0] - 0.5, x[0] + 0.5])
        else:
            mid = 0.5 * (x[:-1] + x[1:])
            edges = np.concatenate(([2.0 * x[0] - mid[0]], mid, [2.0 * x[-1] - mid[-1]]))
        return np.exp(edges) if geometric else edges

    H = np.asarray(H, dtype=np.float64)
    uv = np.asarray(unfolded_coords, dtype=np.float64)
    _V, T = H.shape
    if t_values is None:
        t_values = np.arange(T, dtype=np.float64)
    t_values = np.asarray(t_values, dtype=np.float64)

    # --- select spatial coordinate ---
    col_idx = 0 if axis == "AP" else 1
    pos = uv[:, col_idx]

    # --- bin vertices along the chosen axis ---
    bin_edges = np.linspace(pos.min(), pos.max(), n_bins + 1)
    bin_idx = np.clip(np.digitize(pos, bin_edges) - 1, 0, n_bins - 1)

    # average H within each spatial bin → (n_bins, T), vectorised
    sums = np.zeros((n_bins, T), dtype=np.float64)
    np.add.at(sums, bin_idx, H)
    counts = np.bincount(bin_idx, minlength=n_bins).astype(np.float64)
    H_binned = np.full((n_bins, T), np.nan)
    filled = counts > 0
    H_binned[filled] = sums[filled] / counts[filled, None]

    # --- normalisation ---
    finite = H_binned[np.isfinite(H_binned)]
    positive = finite[finite > 0]
    if log_norm and positive.size:
        vmin_safe = max(float(np.percentile(positive, 2)), 1e-12)
        vmax = float(np.percentile(finite, 98))
        norm = LogNorm(vmin=vmin_safe, vmax=max(vmax, vmin_safe * 2))
    elif finite.size:
        norm = Normalize(vmin=float(finite.min()), vmax=float(np.percentile(finite, 98)))
    else:
        norm = Normalize()

    # --- plot ---
    # explicit cell edges: for log-spaced t, linear midpoints (what
    # shading="auto" uses) can push the first cell below zero on a log axis
    geometric_t = log_t and bool(np.all(t_values > 0))
    t_edges = _cell_edges(t_values, geometric_t)
    fig, ax = plt.subplots(figsize=figsize)
    mesh_plot = ax.pcolormesh(
        t_edges,
        bin_edges,
        H_binned,
        shading="flat",
        cmap=cmap,
        norm=norm,
        rasterized=True,
    )
    if log_t:
        if geometric_t:
            ax.set_xscale("log")
        else:
            positive_t = t_values[t_values > 0]
            ax.set_xscale(
                "symlog", linthresh=float(positive_t.min()) if positive_t.size else 1.0
            )
    ax.set_xlabel("Diffusion scale t (log)" if log_t else "Scale t")
    ax.set_ylabel(f"{axis} coordinate")
    ax.set_title(title or f"Hovmöller — {descriptor_name} along {axis}")
    cbar = fig.colorbar(mesh_plot, ax=ax, pad=0.02)
    cbar.set_label(descriptor_name)
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §24 KYMOGRAPH — 1D line through the surface × scale
# ======================================================================
[docs]
def plot_kymograph(
    unfolded_coords: np.ndarray,
    faces: np.ndarray,
    H: np.ndarray,
    t_values: np.ndarray | None = None,
    *,
    line_axis: Literal["AP", "PD"] = "AP",
    line_position: float = 0.5,
    n_samples: int = 200,
    cmap: str = "viridis",
    log_norm: bool = True,
    log_t: bool = True,
    descriptor_name: str = "HKS",
    title: str | None = None,
    figsize: tuple[float, float] = (8, 4),
    save: PathLike | None = None,
) -> tuple[Figure, Axes]:
    """Kymograph: descriptor values along a 1D line × scale.

    Unlike the Hovmöller diagram (which averages over the orthogonal
    axis), the kymograph traces a single line through the unfolded
    surface — e.g. the midline of PD — and plots the descriptor
    along that line for each scale t.

    Parameters
    ----------
    unfolded_coords : (V, 2) array
    faces : (F, 3) array
        Currently unused. Interpolation uses a Delaunay triangulation of
        ``unfolded_coords``, so the sampled line may bridge concave gaps
        in the mesh.
    H : (V, T) array
    t_values : (T,) array or None
        None → integer indices.
    line_axis : str
        Axis along which the line runs. ``"AP"`` → horizontal line
        at fixed PD; any other value → vertical line at fixed AP.
    line_position : float
        Fractional position (0–1) on the orthogonal axis, measured from
        that axis' minimum (0) to its maximum (1) over all vertices.
    n_samples : int
        Number of sample points along the line.
    cmap, log_norm, log_t : str, bool, bool
        Colourmap, log colour scale (falls back to linear if nothing is
        positive), and log t-axis (symlog if any t <= 0).
    descriptor_name, title, figsize, save

    Returns
    -------
    (Figure, Axes)
    """
    _apply_style()
    from matplotlib.colors import LogNorm, Normalize
    from scipy.interpolate import LinearNDInterpolator

    def _cell_edges(centers: np.ndarray, geometric: bool) -> np.ndarray:
        """Return len(centers)+1 boundaries at the (log-)midpoints of *centers*."""
        x = np.log(centers) if geometric else centers
        if x.size == 1:
            edges = np.array([x[0] - 0.5, x[0] + 0.5])
        else:
            mid = 0.5 * (x[:-1] + x[1:])
            edges = np.concatenate(([2.0 * x[0] - mid[0]], mid, [2.0 * x[-1] - mid[-1]]))
        return np.exp(edges) if geometric else edges

    H = np.asarray(H, dtype=np.float64)
    uv = np.asarray(unfolded_coords, dtype=np.float64)
    _V, T = H.shape
    if t_values is None:
        t_values = np.arange(T, dtype=np.float64)
    t_values = np.asarray(t_values, dtype=np.float64)

    # --- define the 1D line (line_position is a fraction of the range) ---
    u, v = uv[:, 0], uv[:, 1]
    if line_axis == "AP":
        fixed = v.min() + line_position * (v.max() - v.min())
        line_u = np.linspace(u.min(), u.max(), n_samples)
        line_v = np.full_like(line_u, fixed)
        spatial_coord, spatial_label = line_u, "AP coordinate"
    else:
        fixed = u.min() + line_position * (u.max() - u.min())
        line_v = np.linspace(v.min(), v.max(), n_samples)
        line_u = np.full_like(line_v, fixed)
        spatial_coord, spatial_label = line_v, "PD coordinate"
    line_pts = np.column_stack([line_u, line_v])

    # --- interpolate every scale at once ---
    # a vector-valued interpolator builds the Delaunay triangulation once and
    # evaluates all T columns in one call (no mutation of internal state)
    interp = LinearNDInterpolator(uv, H)
    kymo = np.asarray(interp(line_pts), dtype=np.float64).reshape(n_samples, T)

    # fill NaN from outside the convex hull with the nearest valid sample
    sample_idx = np.arange(n_samples)
    for k in range(T):
        nans = np.isnan(kymo[:, k])
        if nans.any() and not nans.all():
            valid = ~nans
            kymo[nans, k] = np.interp(sample_idx[nans], sample_idx[valid], kymo[valid, k])

    # --- colour normalisation ---
    finite = kymo[np.isfinite(kymo)]
    positive = finite[finite > 0]
    if log_norm and positive.size:
        vmin_safe = max(float(np.percentile(positive, 2)), 1e-12)
        vmax_val = float(np.percentile(finite, 98))
        norm = LogNorm(vmin=vmin_safe, vmax=max(vmax_val, vmin_safe * 2))
    elif finite.size:
        norm = Normalize(vmin=float(finite.min()), vmax=float(np.percentile(finite, 98)))
    else:
        norm = Normalize()

    # --- plot ---
    # explicit cell edges so log-spaced t never yields a negative first edge
    geometric_t = log_t and bool(np.all(t_values > 0))
    fig, ax = plt.subplots(figsize=figsize)
    mesh_plot = ax.pcolormesh(
        _cell_edges(t_values, geometric_t),
        _cell_edges(spatial_coord, False),
        kymo,
        shading="flat",
        cmap=cmap,
        norm=norm,
        rasterized=True,
    )
    if log_t:
        if geometric_t:
            ax.set_xscale("log")
        else:
            positive_t = t_values[t_values > 0]
            ax.set_xscale(
                "symlog", linthresh=float(positive_t.min()) if positive_t.size else 1.0
            )
    ortho_name = "PD" if line_axis == "AP" else "AP"
    ax.set_xlabel("Diffusion scale t")
    ax.set_ylabel(spatial_label)
    ax.set_title(
        title
        or f"Kymograph — {descriptor_name} along {line_axis} at {ortho_name}={line_position:.2f}"
    )
    cbar = fig.colorbar(mesh_plot, ax=ax, pad=0.02)
    cbar.set_label(descriptor_name)
    fig.tight_layout()
    _savefig(fig, save)
    return fig, ax
# ======================================================================
# §25 WARPED SURFACE — 3D with height = descriptor value (vedo)
# ======================================================================
[docs]
def plot_warped_surface(
    unfolded_coords: np.ndarray,
    faces: np.ndarray,
    scalars: np.ndarray,
    *,
    warp_factor: float = 0.5,
    cmap: str = "magma",
    vmin: float | None = None,
    vmax: float | None = None,
    descriptor_name: str = "HKS",
    views: list[str] | None = None,
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """3D warped surface: unfolded (u, v) + height from descriptor.

    Creates a 3D surface where x = column 0 of ``unfolded_coords``,
    y = column 1, and z = warp_factor × (scalar min-max normalised to
    [0, 1]). The surface is coloured by the raw descriptor values.
    Non-finite scalars get height 0. They are excluded from the min/max
    used for the normalisation, so one NaN does not flatten the whole mesh.

    Parameters
    ----------
    unfolded_coords : (V, 2) array
    faces : (F, 3) array
    scalars : (V,) array
        Descriptor values to warp by and colour with.
    warp_factor : float
        Maximum height, in the same units as ``unfolded_coords``.
    cmap : str
    vmin, vmax : float or None
        Colour range. None → 1st/99th percentiles of ``scalars`` (NaN-aware).
    descriptor_name : str
    views : list of str or None
        Camera presets. None → ``["oblique_left"]``. Unknown names fall
        back to azimuth -45°, elevation 30°.
    lighting, bg, size, scale, save

    Returns
    -------
    (Path, dict)
        Screenshot path and metadata (warp factor, scalar range, views).
    """
    vedo = _get_vedo()
    uv = np.asarray(unfolded_coords, dtype=np.float64)
    fcs = np.asarray(faces, dtype=np.int64)
    sc = np.asarray(scalars, dtype=np.float64)

    # min-max normalise the finite scalars for the z-displacement; NaN/inf
    # vertices stay at height 0 instead of poisoning min/max for every vertex
    finite = np.isfinite(sc)
    sc_norm = np.zeros_like(sc)
    if finite.any():
        lo = sc[finite].min()
        span = sc[finite].max() - lo
        sc_norm[finite] = sc[finite] - lo
        if span > 0:
            sc_norm /= span

    # build 3D vertices: (u, v, warp_factor * normalised_scalar)
    verts_3d = np.column_stack([uv[:, 0], uv[:, 1], warp_factor * sc_norm])

    if views is None:
        views = ["oblique_left"]
    n_views = len(views)
    if size is None:
        size = (600 * n_views, 600)
    if vmin is None:
        vmin = float(np.nanpercentile(sc, 1))
    if vmax is None:
        vmax = float(np.nanpercentile(sc, 99))

    plt_obj = vedo.Plotter(
        shape=(1, n_views),
        offscreen=True,
        size=size,
        bg=bg,
    )
    for vi, view_name in enumerate(views):
        mesh = _build_vedo_mesh(verts_3d, fcs, vedo)
        mesh.pointdata[descriptor_name] = sc
        mesh.cmap(cmap, descriptor_name, vmin=vmin, vmax=vmax)
        mesh.add_scalarbar(title=descriptor_name)
        mesh.lighting(lighting)
        preset = CAMERA_PRESETS.get(view_name, {"azimuth": -45, "elevation": 30})
        plt_obj.at(vi).show(
            mesh,
            title=f"{descriptor_name} (warped)",
            viewup="z",
            zoom=1.0,
            **{k: v for k, v in preset.items() if k in ("azimuth", "elevation")},
        )
    meta = {
        "warp_factor": warp_factor,
        "scalar_range": (vmin, vmax),
        "views": views,
    }
    out = _save_screenshot(plt_obj, save, scale=scale)
    return out, meta
# ======================================================================
# §26 DESCRIPTOR EVOLUTION COMPARISON — HKS vs WKS on unfolded
# ======================================================================
[docs]
def plot_descriptor_evolution_comparison(
    unfolded_coords: np.ndarray,
    faces: np.ndarray,
    H_hks: np.ndarray,
    H_wks: np.ndarray,
    t_values_hks: np.ndarray | None = None,
    e_values_wks: np.ndarray | None = None,
    *,
    n_scales: int = 4,
    cmap_hks: str = "inferno",
    cmap_wks: str = "cividis",
    log_norm: bool = True,
    subfield_labels: np.ndarray | None = None,
    figsize: tuple[float, float] | None = None,
    save: PathLike | None = None,
) -> tuple[Figure, np.ndarray]:
    """Side-by-side HKS vs WKS evolution on the unfolded surface.

    Two-row layout: the top row shows HKS at a log-spaced subset of its
    columns (times), the bottom row shows WKS at a log-spaced subset of its
    columns (energies). The two subsets are chosen independently of each
    other. Each row has its own colourmap, one colour normalisation shared
    by all panels of that row, and its own colourbar.

    Parameters
    ----------
    unfolded_coords : (V, 2) array
        2-D vertex positions used for the triangulation.
    faces : (F, 3) array
        Triangle vertex indices into ``unfolded_coords``.
    H_hks : (V, T_h) array
        HKS values; one column per time value.
    H_wks : (V, T_w) array
        WKS values; one column per energy value.
    t_values_hks : (T_h,) or None
        Values shown in the HKS panel titles. None → column indices.
    e_values_wks : (T_w,) or None
        Values shown in the WKS panel titles. None → column indices.
    n_scales : int
        Requested number of panels per descriptor (must be >= 1). Fewer
        panels are drawn when log-spaced indices collapse onto the same
        column; surplus axes in a row are hidden.
    cmap_hks, cmap_wks : str
        Colourmaps for the HKS and WKS rows.
    log_norm : bool
        Use a logarithmic colour scale. A row whose selected values contain
        no positive entries falls back to a linear scale.
    subfield_labels : (V,) array-like or None
        Integer labels per vertex; when given, black iso-lines are drawn at
        the half-integer levels between consecutive labels on every panel.
    figsize : (float, float) or None
        None → ``(4 * n_cols, 5)``.
    save : path-like or None
        Passed to ``_savefig``.

    Returns
    -------
    (Figure, ndarray of Axes)
        ``axes`` always has shape ``(2, n_cols)``.

    Raises
    ------
    ValueError
        If ``n_scales < 1``.
    """
    _apply_style()
    import matplotlib.tri as mtri
    from matplotlib.colors import LogNorm, Normalize

    if n_scales < 1:
        raise ValueError(f"n_scales must be >= 1, got {n_scales}")

    H_h = np.asarray(H_hks, dtype=np.float64)
    H_w = np.asarray(H_wks, dtype=np.float64)
    uv = np.asarray(unfolded_coords, dtype=np.float64)
    fcs = np.asarray(faces, dtype=np.int64)
    tri = mtri.Triangulation(uv[:, 0], uv[:, 1], fcs)

    T_h = H_h.shape[1]
    T_w = H_w.shape[1]

    def _select_columns(n_total):
        """Return sorted, unique, log-spaced column indices in [0, n_total)."""
        # NOTE(review): geomspace starts at 1, so column 0 is only selected
        # when n_total == 1. Kept to preserve the existing panel selection.
        idx = np.geomspace(1, n_total, n_scales, dtype=int)
        return np.unique(idx.clip(0, n_total - 1))

    sel_h = _select_columns(T_h)
    sel_w = _select_columns(T_w)
    n_cols = max(len(sel_h), len(sel_w))

    if t_values_hks is None:
        t_values_hks = np.arange(T_h, dtype=np.float64)
    if e_values_wks is None:
        e_values_wks = np.arange(T_w, dtype=np.float64)
    t_h = np.asarray(t_values_hks)
    e_w = np.asarray(e_values_wks)

    # Label field and contour levels are identical for every panel, so
    # prepare them once. np.asarray also accepts plain lists.
    labels = None
    label_levels = None
    if subfield_labels is not None:
        labels = np.asarray(subfield_labels).astype(float)
        # Half-integer levels separate consecutive labels 0..max.
        label_levels = np.arange(labels.max()) + 0.5

    if figsize is None:
        figsize = (4.0 * n_cols, 5.0)
    # squeeze=False keeps a 2-D axes array even when n_cols == 1.
    fig, axes = plt.subplots(
        2,
        n_cols,
        figsize=figsize,
        sharex=True,
        sharey=True,
        constrained_layout=True,
        squeeze=False,
    )

    def _make_norm(H_block, indices):
        """Build the colour normalisation shared by all panels of one row."""
        vals = H_block[:, indices]
        vmax = float(np.nanpercentile(vals, 98))
        if log_norm:
            # NaN compares False, so NaNs are dropped along with non-positives.
            positive = vals[vals > 0]
            if positive.size:
                lo = max(float(np.percentile(positive, 2)), 1e-12)
                # LogNorm requires 0 < vmin < vmax; widen a degenerate range.
                return LogNorm(vmin=lo, vmax=max(vmax, lo * 2))
            # No positive values: a log scale is undefined → linear fallback.
        return Normalize(vmin=float(np.nanmin(vals)), vmax=vmax)

    def _draw_row(row, H_block, indices, coord_values, symbol, cmap):
        """Draw one descriptor row; return the last mappable (or None)."""
        norm = _make_norm(H_block, indices)
        mappable = None
        for pi, ci in enumerate(indices):
            ax = axes[row, pi]
            mappable = ax.tripcolor(
                tri, H_block[:, ci], shading="gouraud", cmap=cmap, norm=norm, rasterized=True
            )
            ax.set_title(f"{symbol} = {coord_values[ci]:.2g}", fontsize=7)
            ax.set_aspect("equal")
            # Skip contouring when there is no boundary level to draw.
            if label_levels is not None and label_levels.size:
                ax.tricontour(
                    tri,
                    labels,
                    levels=label_levels,
                    colors="k",
                    linewidths=0.4,
                )
        # Hide surplus axes when this row has fewer panels than n_cols.
        for pi in range(len(indices), n_cols):
            axes[row, pi].set_visible(False)
        return mappable

    tpc_h = _draw_row(0, H_h, sel_h, t_h, "t", cmap_hks)
    tpc_w = _draw_row(1, H_w, sel_w, e_w, "E", cmap_wks)

    # row labels
    axes[0, 0].set_ylabel("HKS\nPD", fontsize=8)
    axes[1, 0].set_ylabel("WKS\nPD", fontsize=8)

    # one independent colourbar per row
    for mappable, row, label in ((tpc_h, 0, "HKS"), (tpc_w, 1, "WKS")):
        if mappable is not None:
            fig.colorbar(
                mappable,
                ax=axes[row, :].tolist(),
                location="right",
                shrink=0.7,
                pad=0.02,
                label=label,
            )

    _savefig(fig, save)
    return fig, axes