"""3D mesh rendering with vedo and optional PyVista fallback.
This module handles publication-quality 3D renders of triangular meshes
for SpectralBrain's mesh-based analysis pathway. It complements
``points.py`` (which handles atlas-free point clouds) by covering
scenarios where mesh connectivity *is* available — either from
FreeSurfer surface files, HippUnfold outputs, or reconstructed from
point clouds via Poisson / Delaunay.
The six figure types cover the critical mesh visual outputs:
1. **Surface render** — smooth-shaded mesh with optional scalar
overlay (HKS, WKS, thickness, curvature).
2. **Wireframe render** — mesh topology visualisation for QC and
methods figures.
3. **Curvature map** — Gaussian, mean, principal curvatures computed
and displayed directly on the mesh surface.
4. **Multi-view panel** — same mesh from multiple camera angles
(anterior, posterior, lateral, medial, superior, inferior).
5. **Mesh comparison** — side-by-side panels comparing two or more
meshes (e.g., left vs right hemisphere, patient vs control).
6. **Scalar difference map** — vertex-wise difference between two
meshes overlaid as a diverging colourmap.
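All vedo-based functions follow the SpectralBrain convention and return
``(Path, metadata_dict)``; the PyVista fallback ``plot_mesh_pyvista``
returns only the output ``Path`` (or ``None`` when PyVista is unavailable).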
"""
from __future__ import annotations
import os
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, Any
import numpy as np
from spectralbrain.runtime import PathLike, get_logger
if TYPE_CHECKING:
import vedo
logger = get_logger(__name__)
# ---------------------------------------------------------------------------
# Constants — shared with points.py
# ---------------------------------------------------------------------------
# Render defaults used as keyword defaults by the plotting functions below.
_DEFAULT_SIZE: tuple[int, int] = (1600, 1200)
_DEFAULT_SCALE: int = 2
_DEFAULT_BG: str = "white"

# Curvature name -> integer code handed to ``Mesh.compute_curvature(method=...)``.
CURVATURE_METHODS: dict[str, int] = dict(
    gaussian=0,
    mean=1,
    maximum=2,
    minimum=3,
)

# Named camera orientations; each preset holds an azimuth and an elevation
# that ``plot_multi_view`` forwards to the plotter's ``show`` call.
CAMERA_PRESETS: dict[str, dict[str, Any]] = {
    "anterior": dict(azimuth=0, elevation=0),
    "posterior": dict(azimuth=180, elevation=0),
    "left_lateral": dict(azimuth=-90, elevation=0),
    "right_lateral": dict(azimuth=90, elevation=0),
    "superior": dict(azimuth=0, elevation=90),
    "inferior": dict(azimuth=0, elevation=-90),
    "left_medial": dict(azimuth=90, elevation=0),
    "right_medial": dict(azimuth=-90, elevation=0),
}
# ======================================================================
# §0 Lazy imports & helpers
# ======================================================================
def _ensure_offscreen() -> None:
    """Default the ``VTK_USE_OFFSCREEN`` environment variable to ``"1"``.

    A value that is already present in the environment is left untouched.
    """
    if "VTK_USE_OFFSCREEN" not in os.environ:
        os.environ["VTK_USE_OFFSCREEN"] = "1"
def _get_vedo():
    """Lazy-import vedo after requesting offscreen rendering.

    Returns
    -------
    module
        The imported ``vedo`` module.

    Raises
    ------
    ImportError
        If vedo is not installed. The original import error is chained
        as the cause.
    """
    _ensure_offscreen()
    # Keep the try body minimal: only the import itself should be translated
    # into the "please install vedo" message.
    try:
        import vedo
    except ImportError as err:
        raise ImportError(
            "vedo is required for mesh visualization. Install with: pip install vedo"
        ) from err
    # Starting a virtual framebuffer is best-effort: any failure is tolerated,
    # but it is logged so a later blank render can be diagnosed.
    try:
        vedo.start_xvfb()
    except Exception as err:
        logger.debug("vedo.start_xvfb() unavailable or failed: %s", err)
    return vedo
def _save_screenshot(plotter, save: PathLike | None, *, scale: int = _DEFAULT_SCALE) -> Path:
    """Capture a vedo Plotter to PNG and close it.

    Parameters
    ----------
    plotter
        Object exposing ``screenshot(path, scale=...)`` and ``close()``.
    save : path-like or None
        Destination file. None → a fresh temporary ``.png`` file is created
        with ``tempfile.mkstemp``; it is not deleted by this function.
    scale : int
        Forwarded to ``plotter.screenshot``.

    Returns
    -------
    Path
        Location of the written image.
    """
    try:
        if save is None:
            fd, tmp_name = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            save = tmp_name
        out_path = Path(save)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        plotter.screenshot(str(out_path), scale=scale)
    finally:
        # Always release the plotter, even when the capture itself fails.
        plotter.close()
    logger.info("Saved mesh render → %s", out_path)
    return out_path
def _build_vedo_mesh(
    vertices: np.ndarray,
    faces: np.ndarray,
    vedo_module,
) -> vedo.Mesh:
    """Construct a vedo Mesh from numpy arrays.

    Parameters
    ----------
    vertices : (V, 3) array
        Converted to ``float64``.
    faces : (F, 3) array of int indices
        Converted to ``int``.
    vedo_module : the vedo module (passed to avoid re-import)

    Returns
    -------
    vedo.Mesh

    Raises
    ------
    ValueError
        If ``vertices`` is not of shape (V, 3) or ``faces`` is not of
        shape (F, 3).
    """
    vertices = np.asarray(vertices, dtype=np.float64)
    faces = np.asarray(faces, dtype=int)
    # Explicit raises instead of ``assert`` so the checks survive ``python -O``.
    if vertices.ndim != 2 or vertices.shape[1] != 3:
        raise ValueError(f"vertices must have shape (V, 3), got {vertices.shape}")
    if faces.ndim != 2 or faces.shape[1] != 3:
        raise ValueError(f"faces must have shape (F, 3), got {faces.shape}")
    return vedo_module.Mesh([vertices, faces])
def _resolve_cmap(scalar_name: str | None, cmap: str | None) -> str:
    """Pick colourmap: explicit > name-based > viridis.

    Parameters
    ----------
    scalar_name : str or None
        Scalar label. It is lower-cased and spaces become underscores before
        the lookup.
    cmap : str or None
        Explicit colourmap; returned unchanged when given.

    Returns
    -------
    str
        Colourmap name; ``'viridis'`` when nothing matches.
    """
    if cmap is not None:
        return cmap
    if scalar_name is None:
        return "viridis"
    lookup = {
        "hks": "inferno",
        "wks": "cividis",
        "bks": "magma",
        "gps": "viridis",
        "shapedna": "plasma",
        "curvature": "RdBu_r",
        "mean": "RdBu_r",
        "gaussian": "RdBu_r",
        "thickness": "YlOrRd",
        "z_score": "RdBu_r",
        "difference": "RdBu_r",
    }
    tokens = scalar_name.lower().replace(" ", "_").split("_")
    # Try the longest underscore-joined prefix first so multi-token keys such
    # as "z_score" can match; the final attempt is the first token alone,
    # which keeps suffixed names like "hks_t10" resolving to "hks".
    for end in range(len(tokens), 0, -1):
        key = "_".join(tokens[:end])
        if key in lookup:
            return lookup[key]
    return "viridis"
# ======================================================================
# §1 Surface render — smooth-shaded mesh with scalar overlay
# ======================================================================
def plot_mesh(
    vertices: np.ndarray,
    faces: np.ndarray,
    scalars: np.ndarray | None = None,
    scalar_name: str = "HKS",
    cmap: str | None = None,
    vmin: float | None = None,
    vmax: float | None = None,
    color: str = "gold",
    alpha: float = 1.0,
    show_edges: bool = False,
    edge_color: str = "gray",
    edge_width: float = 0.3,
    show_scalarbar: bool = True,
    lighting: str = "default",
    camera: dict[str, Any] | None = None,
    title: str | None = None,
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] = _DEFAULT_SIZE,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Render a triangular mesh with optional scalar overlay.

    This is the primary mesh visualisation: a shaded surface optionally
    coloured by a per-vertex spectral descriptor, morphometric measure,
    or statistical map.

    Parameters
    ----------
    vertices : (V, 3) array-like
        Mesh vertex coordinates.
    faces : (F, 3) array-like
        Triangle index array.
    scalars : (V,) array-like or None
        Per-vertex scalar values. None → uniform ``color``.
    scalar_name : str
        Label for colourbar and automatic cmap selection.
    cmap : str or None
        Colourmap. None → auto from scalar_name.
    vmin, vmax : float or None
        Colour range. None → 1st / 99th percentiles (NaNs ignored).
    color : str
        Uniform mesh colour when scalars is None.
    alpha : float
        Mesh opacity (0–1).
    show_edges : bool
        Overlay wireframe edges.
    edge_color, edge_width : str, float
        Edge appearance.
    show_scalarbar : bool
        Display colourbar.
    lighting : str
        Lighting style forwarded to ``mesh.lighting`` — ``'default'``,
        ``'metallic'``, ``'plastic'``, ``'shiny'``, ``'glossy'``.
    camera : dict or None
        Camera configuration (``pos``, ``focal_point``, ``viewup``).
    title : str or None
        Figure title.
    bg, size, scale, save
        Standard render parameters.

    Returns
    -------
    (Path, dict)
        PNG path and metadata with ``'n_vertices'``, ``'n_faces'``,
        ``'scalar_range'``, ``'cmap'``.

    Raises
    ------
    ValueError
        If ``scalars`` does not have one value per vertex.
    """
    vedo = _get_vedo()
    # Coerce once up front so list / tuple inputs work for the metadata and
    # the length check below, not only for the mesh construction.
    vertices = np.asarray(vertices, dtype=np.float64)
    faces = np.asarray(faces, dtype=int)
    mesh = _build_vedo_mesh(vertices, faces, vedo)
    cmap_name = _resolve_cmap(scalar_name, cmap)
    meta: dict[str, Any] = {
        "n_vertices": vertices.shape[0],
        "n_faces": faces.shape[0],
        "cmap": cmap_name,
        "scalar_range": None,
    }

    if scalars is None:
        mesh.color(color)
    else:
        scalars = np.asarray(scalars, dtype=np.float64)
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if scalars.shape[0] != vertices.shape[0]:
            raise ValueError(
                f"scalars ({scalars.shape[0]}) must match vertices ({vertices.shape[0]})"
            )
        if vmin is None:
            vmin = float(np.nanpercentile(scalars, 1))
        if vmax is None:
            vmax = float(np.nanpercentile(scalars, 99))
        mesh.pointdata[scalar_name] = scalars
        mesh.cmap(cmap_name, scalar_name, vmin=vmin, vmax=vmax)
        if show_scalarbar:
            mesh.add_scalarbar(title=scalar_name)
        meta["scalar_range"] = (vmin, vmax)

    mesh.alpha(alpha)
    mesh.lighting(lighting)
    if show_edges:
        mesh.linewidth(edge_width).linecolor(edge_color)

    plt = vedo.Plotter(offscreen=True, size=size, bg=bg, title=title or "")
    show_kw: dict[str, Any] = {"viewup": "z", "zoom": 1.2}
    if camera is not None:
        show_kw["camera"] = camera
    plt.show(mesh, **show_kw)
    out = _save_screenshot(plt, save, scale=scale)
    return out, meta
# ======================================================================
# §2 Wireframe render
# ======================================================================
def plot_wireframe(
    vertices: np.ndarray,
    faces: np.ndarray,
    *,
    color: str = "steelblue",
    linewidth: float = 0.5,
    alpha: float = 1.0,
    camera: dict[str, Any] | None = None,
    title: str | None = None,
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] = _DEFAULT_SIZE,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Wireframe render of a mesh for topology inspection.

    Useful for QC of reconstructed surfaces and for methods figures
    that need to show mesh structure clearly.

    Parameters
    ----------
    vertices : (V, 3) array-like
    faces : (F, 3) array-like
    color : str
        Wire colour.
    linewidth : float
        Wire thickness.
    alpha : float
        Opacity.
    camera, title, bg, size, scale, save
        Standard render parameters.

    Returns
    -------
    (Path, dict)
        PNG path and metadata with ``'n_vertices'`` and ``'n_faces'``.
    """
    vedo = _get_vedo()
    # Coerce up front so list inputs do not fail on ``.shape`` after the
    # (expensive) render has already been produced.
    vertices = np.asarray(vertices, dtype=np.float64)
    faces = np.asarray(faces, dtype=int)
    mesh = _build_vedo_mesh(vertices, faces, vedo)
    mesh.wireframe(True).color(color).linewidth(linewidth).alpha(alpha)

    plt = vedo.Plotter(offscreen=True, size=size, bg=bg, title=title or "")
    show_kw: dict[str, Any] = {"viewup": "z", "zoom": 1.2}
    if camera is not None:
        show_kw["camera"] = camera
    plt.show(mesh, **show_kw)

    meta = {"n_vertices": vertices.shape[0], "n_faces": faces.shape[0]}
    out = _save_screenshot(plt, save, scale=scale)
    return out, meta
# ======================================================================
# §3 Curvature map
# ======================================================================
def plot_curvature(
    vertices: np.ndarray,
    faces: np.ndarray,
    method: str = "mean",
    *,
    cmap: str = "RdBu_r",
    vmin: float | None = None,
    vmax: float | None = None,
    symmetric: bool = True,
    title: str | None = None,
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] = _DEFAULT_SIZE,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Compute and render curvature on a mesh surface.

    Curvature is computed with ``mesh.compute_curvature`` and displayed
    immediately with the requested colourmap.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    method : {'gaussian', 'mean', 'maximum', 'minimum'}
        Curvature type (case-insensitive).
    cmap : str
        Colourmap (diverging recommended for curvature).
    vmin, vmax : float or None
        Colour range. If *symmetric* is True and these are None,
        range is set to ± 95th percentile of the absolute curvature;
        otherwise missing limits use the 1st / 99th percentiles.
    symmetric : bool
        Centre the colourmap on zero.
    title, bg, size, scale, save
        Standard render parameters.

    Returns
    -------
    (Path, dict)
        PNG path and metadata with ``'curvature_method'``,
        ``'curvature_stats'`` (mean, std, min, max over finite values),
        ``'vmin'`` and ``'vmax'``.

    Raises
    ------
    ValueError
        If *method* is unknown or the curvature has no finite values.
    RuntimeError
        If no curvature point array can be found after the computation.
    """
    vedo = _get_vedo()
    mesh = _build_vedo_mesh(vertices, faces, vedo)

    method_code = CURVATURE_METHODS.get(method.lower())
    if method_code is None:
        raise ValueError(
            f"Unknown curvature method '{method}'. Choose from: {list(CURVATURE_METHODS.keys())}"
        )
    mesh.compute_curvature(method=method_code)

    # NOTE(review): VTK's vtkCurvatures filter is understood to name its output
    # per curvature type rather than plain "Curvature" — confirm against the
    # installed VTK / vedo. The legacy name is kept as a fallback.
    vtk_array_names = {
        0: "Gauss_Curvature",
        1: "Mean_Curvature",
        2: "Maximum_Curvature",
        3: "Minimum_Curvature",
    }
    available = list(mesh.pointdata.keys())
    array_name = next(
        (name for name in (vtk_array_names[method_code], "Curvature") if name in available),
        None,
    )
    if array_name is None:
        raise RuntimeError(
            f"No curvature point array found after compute_curvature; available: {available}"
        )

    curv = np.asarray(mesh.pointdata[array_name], dtype=np.float64)
    curv_clean = curv[np.isfinite(curv)]
    if curv_clean.size == 0:
        raise ValueError("Curvature contains no finite values; cannot render.")

    # Auto colour range: only the limits left as None are filled in.
    if vmin is None or vmax is None:
        if symmetric:
            p95 = float(np.percentile(np.abs(curv_clean), 95))
            vmin = vmin if vmin is not None else -p95
            vmax = vmax if vmax is not None else p95
        else:
            vmin = vmin if vmin is not None else float(np.percentile(curv_clean, 1))
            vmax = vmax if vmax is not None else float(np.percentile(curv_clean, 99))

    label = f"{method.capitalize()} curvature"
    mesh.cmap(cmap, array_name, vmin=vmin, vmax=vmax)
    mesh.add_scalarbar(title=label)

    plt = vedo.Plotter(offscreen=True, size=size, bg=bg, title=title or label)
    plt.show(mesh, viewup="z", zoom=1.2)

    meta = {
        "curvature_method": method,
        "curvature_stats": {
            "mean": float(np.mean(curv_clean)),
            "std": float(np.std(curv_clean)),
            "min": float(np.min(curv_clean)),
            "max": float(np.max(curv_clean)),
        },
        "vmin": vmin,
        "vmax": vmax,
    }
    out = _save_screenshot(plt, save, scale=scale)
    return out, meta
# ======================================================================
# §4 Multi-view panel — same mesh from multiple camera angles
# ======================================================================
def plot_multi_view(
    vertices: np.ndarray,
    faces: np.ndarray,
    scalars: np.ndarray | None = None,
    scalar_name: str = "HKS",
    cmap: str | None = None,
    vmin: float | None = None,
    vmax: float | None = None,
    views: list[str] | None = None,
    *,
    color: str = "gold",
    lighting: str = "default",
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Multi-view panel showing the same mesh from different angles.

    Renders the same mesh (optionally with scalar overlay) in a 1×N
    panel strip, one panel per entry in *views*.

    Parameters
    ----------
    vertices : (V, 3) array-like
    faces : (F, 3) array-like
    scalars : (V,) array-like or None
        Per-vertex values. None → uniform ``color``.
    scalar_name : str
        Label for colourbar and automatic cmap selection.
    cmap : str or None
        Colourmap. None → auto from scalar_name.
    vmin, vmax : float or None
        Colour range. None → 1st / 99th percentiles (NaNs ignored).
    views : list of str or None
        Camera preset names from ``CAMERA_PRESETS``. None defaults
        to ``['left_lateral', 'anterior', 'superior', 'right_lateral']``.
        Names missing from ``CAMERA_PRESETS`` are rendered with no
        azimuth / elevation override and a warning is logged.
    color : str
        Uniform colour when scalars is None.
    lighting : str
        Lighting preset forwarded to ``mesh.lighting``.
    bg, size, scale, save
        Standard render parameters. ``size`` None → 600 px per view.

    Returns
    -------
    (Path, dict)
        PNG path and metadata with ``'n_vertices'``, ``'n_faces'``,
        ``'views'``, ``'scalar_range'``.

    Raises
    ------
    ValueError
        If *views* is empty or ``scalars`` does not have one value per vertex.
    """
    vedo = _get_vedo()
    if views is None:
        views = ["left_lateral", "anterior", "superior", "right_lateral"]
    n_views = len(views)
    if n_views == 0:
        # A 1x0 grid would otherwise reach the plotter with a zero-width window.
        raise ValueError("views must contain at least one camera preset name")
    if size is None:
        size = (600 * n_views, 600)
    cmap_name = _resolve_cmap(scalar_name, cmap)

    # Coerce once so list inputs work for the metadata and the length check.
    vertices = np.asarray(vertices, dtype=np.float64)
    faces = np.asarray(faces, dtype=int)

    # Build the base mesh once, then clone per view
    base = _build_vedo_mesh(vertices, faces, vedo)
    if scalars is not None:
        scalars = np.asarray(scalars, dtype=np.float64)
        if scalars.shape[0] != vertices.shape[0]:
            raise ValueError(
                f"scalars ({scalars.shape[0]}) must match vertices ({vertices.shape[0]})"
            )
        if vmin is None:
            vmin = float(np.nanpercentile(scalars, 1))
        if vmax is None:
            vmax = float(np.nanpercentile(scalars, 99))
        base.pointdata[scalar_name] = scalars
        base.cmap(cmap_name, scalar_name, vmin=vmin, vmax=vmax)
        base.add_scalarbar(title=scalar_name)
    else:
        base.color(color)
    base.lighting(lighting)

    plt = vedo.Plotter(
        shape=(1, n_views),
        offscreen=True,
        size=size,
        bg=bg,
    )
    for i, view_name in enumerate(views):
        preset = CAMERA_PRESETS.get(view_name)
        if preset is None:
            # Keep rendering (best-effort), but make the fallback visible:
            # the panel title would otherwise claim a view that was not applied.
            logger.warning(
                "Unknown camera preset %r; rendering with the default camera", view_name
            )
            preset = {}
        angles = {k: v for k, v in preset.items() if k in ("azimuth", "elevation")}
        plt.at(i).show(
            base.clone(),
            title=view_name.replace("_", " ").title(),
            viewup="z",
            zoom=1.1,
            **angles,
        )

    meta = {
        "n_vertices": vertices.shape[0],
        "n_faces": faces.shape[0],
        "views": views,
        "scalar_range": (vmin, vmax) if scalars is not None else None,
    }
    out = _save_screenshot(plt, save, scale=scale)
    return out, meta
# ======================================================================
# §5 Mesh comparison — side-by-side panels
# ======================================================================
def plot_mesh_comparison(
    meshes: list[dict[str, Any]],
    *,
    shape: tuple[int, int] | None = None,
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Side-by-side comparison of multiple meshes.

    Each element in *meshes* is a dict with keys:

    - ``'vertices'`` : (V, 3) array (required)
    - ``'faces'`` : (F, 3) array (required)
    - ``'scalars'`` : (V,) array or None
    - ``'scalar_name'`` : str (default ``'value'``)
    - ``'cmap'`` : str or None
    - ``'vmin'``, ``'vmax'`` : float or None (None → 1st / 99th percentile)
    - ``'color'`` : str (default ``'gold'``)
    - ``'title'`` : str (default ``''``)

    Parameters
    ----------
    meshes : list of dict
        One dict per mesh panel.
    shape : (rows, cols) or None
        Grid layout. None → single row.
    bg, size, scale, save
        Standard render parameters. ``size`` None → 600 px per grid cell.

    Returns
    -------
    (Path, dict)
        PNG path and metadata with ``'n_panels'`` and ``'shape'``.

    Raises
    ------
    ValueError
        If *meshes* is empty or *shape* has fewer cells than there are meshes.
    """
    vedo = _get_vedo()
    n = len(meshes)
    if n == 0:
        raise ValueError("meshes must contain at least one mesh specification")
    if shape is None:
        shape = (1, n)
    if shape[0] * shape[1] < n:
        # Fail before any rendering instead of partway through the panel loop.
        raise ValueError(f"shape {shape} has fewer cells than the {n} meshes to display")
    if size is None:
        size = (600 * shape[1], 600 * shape[0])

    plt = vedo.Plotter(shape=shape, offscreen=True, size=size, bg=bg)
    for i, spec in enumerate(meshes):
        m = _build_vedo_mesh(
            np.asarray(spec["vertices"]),
            np.asarray(spec["faces"]),
            vedo,
        )
        scalars = spec.get("scalars")
        scalar_name = spec.get("scalar_name", "value")
        cmap_name = _resolve_cmap(scalar_name, spec.get("cmap"))
        panel_title = spec.get("title", "")

        if scalars is not None:
            scalars = np.asarray(scalars, dtype=np.float64)
            # Compare against None explicitly: a truthiness test (``or``) would
            # discard a legitimate limit of exactly 0.0.
            v0 = spec.get("vmin")
            if v0 is None:
                v0 = float(np.nanpercentile(scalars, 1))
            v1 = spec.get("vmax")
            if v1 is None:
                v1 = float(np.nanpercentile(scalars, 99))
            m.pointdata[scalar_name] = scalars
            m.cmap(cmap_name, scalar_name, vmin=v0, vmax=v1)
            m.add_scalarbar(title=scalar_name)
        else:
            m.color(spec.get("color", "gold"))

        plt.at(i).show(m, title=panel_title, viewup="z", zoom=1.1)

    meta = {"n_panels": n, "shape": shape}
    out = _save_screenshot(plt, save, scale=scale)
    return out, meta
# ======================================================================
# §6 Scalar difference map
# ======================================================================
def plot_scalar_difference(
    vertices: np.ndarray,
    faces: np.ndarray,
    scalars_a: np.ndarray,
    scalars_b: np.ndarray,
    *,
    label_a: str = "A",
    label_b: str = "B",
    diff_cmap: str = "RdBu_r",
    symmetric: bool = True,
    show_individual: bool = True,
    individual_cmap: str | None = None,
    bg: str = _DEFAULT_BG,
    size: tuple[int, int] | None = None,
    scale: int = _DEFAULT_SCALE,
    save: PathLike | None = None,
) -> tuple[Path, dict[str, Any]]:
    """Vertex-wise scalar difference map between two conditions.

    Computes ``scalars_a - scalars_b`` and displays the difference
    on the mesh surface. Optionally shows individual maps alongside.

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    scalars_a, scalars_b : (V,) arrays
        Per-vertex values for conditions A and B; shapes must be equal.
    label_a, label_b : str
        Labels for panels.
    diff_cmap : str
        Colourmap for the difference (diverging recommended).
    symmetric : bool
        True → colour range is ± the 95th percentile of the absolute
        difference; False → 1st / 99th percentiles of the difference.
    show_individual : bool
        Show A and B alongside the difference (3-panel layout).
    individual_cmap : str or None
        Colourmap for individual panels. None → 'viridis'.
    bg, size, scale, save
        Standard render parameters. ``size`` None → 600 px per panel.

    Returns
    -------
    (Path, dict)
        PNG path and metadata with ``'diff_stats'`` (computed over vertices
        where the difference is finite), ``'vmin'``, ``'vmax'``, ``'n_panels'``.

    Raises
    ------
    ValueError
        If the two scalar arrays differ in shape or their difference has no
        finite values.
    """
    vedo = _get_vedo()
    scalars_a = np.asarray(scalars_a, dtype=np.float64)
    scalars_b = np.asarray(scalars_b, dtype=np.float64)
    # Refuse mismatched inputs instead of letting NumPy broadcast them silently.
    if scalars_a.shape != scalars_b.shape:
        raise ValueError(
            f"scalars_a {scalars_a.shape} and scalars_b {scalars_b.shape} must have the same shape"
        )
    diff = scalars_a - scalars_b
    diff_clean = diff[np.isfinite(diff)]
    if diff_clean.size == 0:
        raise ValueError("Difference contains no finite values; cannot render.")

    n_panels = 3 if show_individual else 1
    if size is None:
        size = (600 * n_panels, 600)
    plt = vedo.Plotter(
        shape=(1, n_panels),
        offscreen=True,
        size=size,
        bg=bg,
    )

    panel_idx = 0
    ind_cmap = individual_cmap or "viridis"
    if show_individual:
        # Panels A and B are built identically; only label and data differ.
        for label, values in ((label_a, scalars_a), (label_b, scalars_b)):
            m_ind = _build_vedo_mesh(vertices, faces, vedo)
            m_ind.pointdata[label] = values
            m_ind.cmap(ind_cmap, label)
            m_ind.add_scalarbar(title=label)
            plt.at(panel_idx).show(m_ind, title=label, viewup="z", zoom=1.1)
            panel_idx += 1

    # Difference panel
    m_diff = _build_vedo_mesh(vertices, faces, vedo)
    m_diff.pointdata["Difference"] = diff
    if symmetric:
        p95 = float(np.percentile(np.abs(diff_clean), 95))
        d_vmin, d_vmax = -p95, p95
    else:
        d_vmin = float(np.percentile(diff_clean, 1))
        d_vmax = float(np.percentile(diff_clean, 99))
    m_diff.cmap(diff_cmap, "Difference", vmin=d_vmin, vmax=d_vmax)
    m_diff.add_scalarbar(title=f"{label_a} − {label_b}")
    plt.at(panel_idx).show(
        m_diff,
        title=f"Difference ({label_a} − {label_b})",
        viewup="z",
        zoom=1.1,
    )

    # All statistics use the same finite subset as the colour range, so the
    # percentage is not diluted by NaN vertices in its denominator.
    meta = {
        "diff_stats": {
            "mean": float(np.mean(diff_clean)),
            "std": float(np.std(diff_clean)),
            "min": float(np.min(diff_clean)),
            "max": float(np.max(diff_clean)),
            "pct_positive": float(np.mean(diff_clean > 0) * 100),
        },
        "vmin": d_vmin,
        "vmax": d_vmax,
        "n_panels": n_panels,
    }
    out = _save_screenshot(plt, save, scale=scale)
    return out, meta
# ======================================================================
# §7 PyVista fallback — basic mesh render
# ======================================================================
def plot_mesh_pyvista(
    vertices: np.ndarray,
    faces: np.ndarray,
    scalars: np.ndarray | None = None,
    cmap: str = "viridis",
    *,
    show_edges: bool = False,
    window_size: tuple[int, int] = (1600, 1200),
    save: PathLike | None = None,
) -> Path | None:
    """Minimal PyVista mesh render (fallback when vedo unavailable).

    Parameters
    ----------
    vertices : (V, 3) array
    faces : (F, 3) array
    scalars : (V,) array or None
        Per-vertex values used for colouring. None → no scalar overlay.
    cmap : str
        Colourmap forwarded to ``add_mesh``.
    show_edges : bool
        Forwarded to ``add_mesh``.
    window_size : (int, int)
        Render window size.
    save : path or None
        Destination file. None → a temporary ``.png`` created with
        ``tempfile.mkstemp`` (not deleted by this function).

    Returns
    -------
    Path or None
        Output path if successful, None when PyVista cannot be imported.
    """
    try:
        import pyvista as pv
    except ImportError:
        logger.warning("PyVista not available — cannot render mesh")
        return None

    pv.OFF_SCREEN = True
    vertices = np.asarray(vertices, dtype=np.float64)
    faces = np.asarray(faces, dtype=int)
    # PyVista expects faces as [3, i, j, k, 3, i, j, k, ...]
    pv_faces = np.column_stack([np.full(len(faces), 3, dtype=int), faces]).ravel()
    mesh = pv.PolyData(vertices, pv_faces)
    if scalars is not None:
        mesh.point_data["scalars"] = np.asarray(scalars, dtype=np.float64)

    plotter = pv.Plotter(off_screen=True, window_size=window_size)
    try:
        plotter.add_mesh(
            mesh,
            scalars="scalars" if scalars is not None else None,
            cmap=cmap,
            show_edges=show_edges,
        )
        plotter.view_isometric()
        if save is None:
            fd, tmp_name = tempfile.mkstemp(suffix=".png")
            os.close(fd)
            save = tmp_name
        out_path = Path(save)
        # Create missing parent directories, matching the vedo save helper.
        out_path.parent.mkdir(parents=True, exist_ok=True)
        plotter.screenshot(str(out_path))
    finally:
        # Release the render window even if drawing or saving fails.
        plotter.close()
    logger.info("Saved PyVista render → %s", out_path)
    return out_path
# ======================================================================
# __all__
# ======================================================================
# Public API, listed alphabetically: the two upper-case names are module
# constants, ``plot_mesh_pyvista`` is the PyVista fallback, and the remaining
# ``plot_*`` entries are the vedo-based renders.
__all__ = [
    "CAMERA_PRESETS",
    "CURVATURE_METHODS",
    "plot_curvature",
    "plot_mesh",
    "plot_mesh_comparison",
    "plot_mesh_pyvista",
    "plot_multi_view",
    "plot_scalar_difference",
    "plot_wireframe",
]