"""Spectral Graph Wavelets (SGW) for multi-resolution shape analysis.
Implements the wavelet framework of Hammond, Vandergheynst &
Gribonval (ACHA 2011): define a band-pass filter g on the spectral
domain and apply it via Chebyshev polynomial approximation of
g(t·L), **without** explicit eigendecomposition.
Also supports wavelet computation from precomputed eigenpairs (faster
for the same decomposition used by HKS/WKS).
Kernels
-------
- **Mexican hat** — g(x) = x · exp(-x) — the canonical SGW kernel.
- **Heat** — g(x) = exp(-x) — low-pass (same filter as HKS).
- **Meyer** — compactly supported in spectral domain.
- **Custom** — user-defined callable g(x).
"""
from __future__ import annotations
from collections.abc import Callable
from typing import Literal
import numpy as np
import scipy.sparse as sp
from spectralbrain.core.base import SpectralDecomposition
from spectralbrain.runtime import (
DescriptorMatrix,
SparseMatrix,
get_logger,
progress_simple,
)
logger = get_logger(__name__)
# ======================================================================
# §1 WAVELET KERNELS
# ======================================================================
[docs]
def mexican_hat_kernel(x: np.ndarray) -> np.ndarray:
"""Mexican-hat (Ricker) wavelet kernel: g(x) = x · exp(-x).
Parameters
----------
x : ndarray
Scaled spectral variable t·λ.
Returns
-------
ndarray
"""
return x * np.exp(-x)
[docs]
def heat_kernel(x: np.ndarray) -> np.ndarray:
"""Heat kernel: g(x) = exp(-x).
Low-pass filter — equivalent to the HKS filter but used here
in the wavelet framework for completeness.
"""
return np.exp(-x)
[docs]
def meyer_kernel(x: np.ndarray) -> np.ndarray:
"""Simplified Meyer-type wavelet kernel.
Compactly supported band-pass: peaks at x ≈ 1, decays to zero
at x = 0 and x → ∞.
"""
v = np.zeros_like(x)
mask1 = (x >= 0.5) & (x < 1.0)
mask2 = (x >= 1.0) & (x < 2.0)
v[mask1] = np.sin(np.pi / 2 * _nu(2 * x[mask1] - 1)) ** 2
v[mask2] = np.cos(np.pi / 2 * _nu(x[mask2] - 1)) ** 2
return v
def _nu(x: np.ndarray) -> np.ndarray:
"""Smooth transition function for Meyer wavelet."""
return x**4 * (35 - 84 * x + 70 * x**2 - 20 * x**3)
# ======================================================================
# §2 CHEBYSHEV APPROXIMATION OF g(t·L)
# ======================================================================
def _chebyshev_coefficients(
kernel: Callable[[np.ndarray], np.ndarray],
K: int,
*,
a: float = 0.0,
b: float = 2.0,
) -> np.ndarray:
"""Compute Chebyshev coefficients for a kernel on [a, b].
Uses the discrete cosine transform of kernel samples at
Chebyshev nodes.
Parameters
----------
kernel : callable
g(x) → y, defined on [a, b].
K : int
Number of Chebyshev coefficients (polynomial order + 1).
a, b : float
Interval bounds.
Returns
-------
ndarray, shape (K,)
"""
N = max(K + 1, 2 * K)
nodes = np.cos(np.pi * (np.arange(N) + 0.5) / N) # in [-1, 1]
x = (nodes + 1) * (b - a) / 2 + a # map to [a, b]
vals = kernel(x)
# DCT-based coefficient estimation.
coeffs = np.zeros(K, dtype=np.float64)
for k in range(K):
coeffs[k] = (2.0 / N) * np.sum(vals * np.cos(np.pi * k * (np.arange(N) + 0.5) / N))
coeffs[0] /= 2.0
return coeffs
def _chebyshev_apply(
L: SparseMatrix,
signal: np.ndarray,
coeffs: np.ndarray,
*,
a: float = 0.0,
b: float = 2.0,
) -> np.ndarray:
"""Apply Chebyshev polynomial approximation of g(L) to a signal.
Uses the three-term recurrence T_{k+1}(x) = 2x·T_k(x) − T_{k-1}(x).
Parameters
----------
L : SparseMatrix, shape (N, N)
signal : ndarray, shape (N,) or (N, d)
coeffs : ndarray, shape (K,)
a, b : float
Spectral bounds.
Returns
-------
ndarray, same shape as signal
"""
L.shape[0]
K = len(coeffs)
# Scale L to [-1, 1]: L̃ = (2L − (a+b)I) / (b−a)
c = 2.0 / (b - a)
d = -(a + b) / (b - a)
# T_0 = I · signal
T_prev = signal.copy() # T_0(L̃) · s
result = coeffs[0] * T_prev
if K == 1:
return result
# T_1 = L̃ · signal
T_curr = c * (L @ signal) + d * signal
result = result + coeffs[1] * T_curr
# Recurrence.
for k in range(2, K):
T_next = 2.0 * (c * (L @ T_curr) + d * T_curr) - T_prev
result = result + coeffs[k] * T_next
T_prev = T_curr
T_curr = T_next
return result
# ======================================================================
# §3 SPECTRAL GRAPH WAVELET TRANSFORM
# ======================================================================
# ======================================================================
# §4 WAVELET DESCRIPTORS FROM EIGENPAIRS
# ======================================================================
[docs]
def sgw_descriptor(
decomp: SpectralDecomposition,
scales: np.ndarray | None = None,
*,
n_scales: int = 5,
kernel: Callable = mexican_hat_kernel,
aggregate: Literal["energy", "raw", "abs_mean"] = "energy",
) -> DescriptorMatrix:
"""Spectral Graph Wavelet descriptor from precomputed eigenpairs.
Faster than Chebyshev-based SGW when the eigenpairs are already
available (from HKS/WKS computation).
.. math::
\\psi_{t}(x) = \\sum_{i=0}^{k}
g(t \\cdot \\lambda_i)\\, \\varphi_i(x)
The per-vertex wavelet energy at scale *t* is:
.. math::
W(x, t) = \\psi_t^2(x) = \\left(
\\sum_i g(t \\lambda_i) \\varphi_i(x)
\\right)^2
Parameters
----------
decomp : SpectralDecomposition
scales : ndarray, shape (S,), optional
Wavelet scales. ``None`` = auto log-spaced.
n_scales : int
Number of auto scales.
kernel : callable
Wavelet kernel g(x).
aggregate : str
``"energy"`` — ψ²(x, t), wavelet energy per vertex per scale.
``"raw"`` — ψ(x, t), raw wavelet coefficients (signed).
``"abs_mean"`` — |ψ(x, t)|, absolute coefficients.
Returns
-------
ndarray, shape (N, S)
Multi-scale wavelet descriptor.
"""
evals = decomp.eigenvalues
evecs = decomp.eigenvectors
if scales is None:
lam_nz = evals[evals > 1e-10]
if len(lam_nz) < 2:
scales = np.logspace(-1, 2, n_scales)
else:
s_min = 1.0 / lam_nz[-1]
s_max = 2.0 / lam_nz[0]
scales = np.logspace(np.log10(s_min), np.log10(s_max), n_scales)
scales = np.asarray(scales, dtype=np.float64)
len(scales)
# g(t·λ) for each scale: (S, k)
g_tl = np.array([kernel(t * evals) for t in scales]) # (S, k)
# ψ_t(x) = Σᵢ g(t·λᵢ)·φᵢ(x) = Φ @ g_tl.T
psi = evecs @ g_tl.T # (N, S)
if aggregate == "energy":
return psi**2
elif aggregate == "abs_mean":
return np.abs(psi)
elif aggregate == "raw":
return psi
else:
raise ValueError(f"Unknown aggregate: {aggregate!r}")
# ======================================================================
__all__: list[str] = [
"heat_kernel",
# Kernels
"mexican_hat_kernel",
"meyer_kernel",
# Eigenpair-based
"sgw_descriptor",
# Chebyshev
"sgw_transform",
]