scattering_stats — Scattering transform statistics

scattering_stats computes wavelet scattering transform (WST) coefficients for ensembles of flat-sky patches. Scattering coefficients capture multi-scale non-Gaussian structure and complement the power spectrum and higher-order moments: they are sensitive to phase correlations between scales, which collapsed moments are not.

Scattering coefficients

  • S1 (first-order): mean modulus of wavelet coefficients per scale and orientation (returned orientation-averaged, see API). Captures the energy distribution across angular scales.

  • S2 (second-order): mean modulus of a second wavelet transform applied to the first-order modulus fields, i.e. cross-scale coupling. The S2/S1 ratio matrix identifies which scale pairs the DDPM fails to reproduce.

  • C11 (scattering covariance): correlations between wavelet-modulus fields at different scales (modulus × modulus). Requires the Cheng et al. backend.
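
The S2/S1 ratio matrix mentioned above can be sketched with plain NumPy. This is a minimal illustration, assuming S1 is stored orientation-averaged with shape (N, J) and S2 with shape (N, J, J, L), as documented in the API section. The helper name s2_over_s1 and the synthetic arrays are hypothetical and not part of the module.

```python
import numpy as np

def s2_over_s1(s1, s2, eps=1e-12):
    """Ensemble-mean S2/S1 ratio matrix of shape (J, J).

    s1: (N, J) first-order coefficients, orientation-averaged.
    s2: (N, J, J, L) second-order coefficients.
    Entry (j1, j2) is S2[j1, j2] averaged over orientation, divided by S1[j1].
    """
    s1 = np.asarray(s1, dtype=float)
    s2 = np.asarray(s2, dtype=float)
    s2_iso = s2.mean(axis=-1)                   # (N, J, J)
    ratio = s2_iso / (s1[:, :, None] + eps)     # normalise by the first scale j1
    return ratio.mean(axis=0)                   # average over patches

# Synthetic stand-ins for a reference and a generated ensemble (illustration only).
rng = np.random.default_rng(0)
N, J, L = 8, 4, 4
s1_ref = rng.uniform(0.5, 1.0, size=(N, J))
s2_ref = rng.uniform(0.1, 0.5, size=(N, J, J, L))
s2_gen = s2_ref.copy()
s2_gen[:, 1, 3, :] *= 0.5                       # suppress coupling at one scale pair

ratio_ref = s2_over_s1(s1_ref, s2_ref)
ratio_gen = s2_over_s1(s1_ref, s2_gen)
frac = ratio_gen / ratio_ref - 1.0              # fractional deviation per (j1, j2)
worst = np.unravel_index(np.argmax(np.abs(frac)), frac.shape)
print("largest deviation at (j1, j2) =", tuple(int(i) for i in worst))
```

Comparing the ratio rather than S2 itself removes the overall amplitude at scale j1, so a deviation points at the coupling between the two scales and not at a plain power mismatch.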

Installation

The module tries to import scattering (Cheng et al. 2021) first, then falls back to kymatio. The Cheng et al. implementation is preferred:

git clone https://github.com/SihaoCheng/scattering_transform.git
# copy scattering/ into the project root, or add to sys.path

# Or install kymatio as a fallback:
pip install kymatio

Usage

from foregrounds_diffusion.scattering_stats import (
    compute_scattering_coefficients,
    scattering_summary,
)

# Compute S1 and S2 for a stack of maps, maps_nhw of shape (N, H, W)
J = 4    # number of dyadic scales
L = 4    # number of orientations
coeffs = compute_scattering_coefficients(maps_nhw, J=J, L=L)   # returns a dict
s1, s2 = coeffs["S1"], coeffs["S2"]
# s1: (N, J),   s2: (N, J, J, L)

# Flatten to a feature vector for comparison
features = scattering_summary(coeffs)   # (N, n_features)

See 11 — Scattering Transforms for a full comparison of scattering coefficients between AGORA, DDPM, and Gaussian maps, including the S2/S1 ratio matrix.
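
One simple way to compare two ensembles from their feature matrices is a per-feature standardized residual of the ensemble means. The sketch below is an assumption about how such a comparison could be done, not the method used in the notebook. The function name standardized_residual and the random feature matrices are hypothetical.

```python
import numpy as np

def standardized_residual(features_ref, features_gen):
    """Residual of ensemble-mean features in units of the reference scatter.

    Both inputs have shape (N, n_features); the output has shape (n_features,).
    """
    ref = np.asarray(features_ref, dtype=float)
    gen = np.asarray(features_gen, dtype=float)
    mu_ref, mu_gen = ref.mean(axis=0), gen.mean(axis=0)
    sigma_ref = ref.std(axis=0, ddof=1)
    # Features with zero scatter get a residual of zero instead of dividing by zero.
    return (mu_gen - mu_ref) / np.where(sigma_ref > 0, sigma_ref, np.inf)

# Random stand-ins for two feature matrices (illustration only).
rng = np.random.default_rng(1)
n_features = 68                       # e.g. J + J*J*L with J = L = 4, an assumed layout
features_ref = rng.normal(1.0, 0.1, size=(200, n_features))
features_gen = rng.normal(1.0, 0.1, size=(200, n_features))
features_gen[:, 0] += 0.5             # inject a bias into one feature

z = standardized_residual(features_ref, features_gen)
print("most discrepant feature index:", int(np.argmax(np.abs(z))))
```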

API

Scattering transform statistics for flat-sky CMB foreground patches.

Wraps the scattering package from Cheng et al. (https://github.com/SihaoCheng/scattering_transform) to compute S1 (first-order) and S2 (second-order) scattering coefficients, and the scattering covariance, for ensembles of flat-sky patches.

These statistics capture non-Gaussian structure at multiple angular scales and orientations, providing a richer description of the CIB and tSZ fields than the power spectrum alone.

Installation

The scattering package is not on PyPI. Clone the repository and add it to your Python path:

git clone https://github.com/SihaoCheng/scattering_transform.git
# Then either:
#   (a) copy the scattering/ folder into your project, or
#   (b) add to sys.path in your notebook:
#       import sys; sys.path.insert(0, '/path/to/scattering_transform')

Or install kymatio as a pip-installable alternative (see notes below):

pip install kymatio

Notes

This module tries to import scattering (Cheng et al.) first. If not found it falls back to kymatio with a warning. The Cheng et al. implementation is preferred because it is faster for large batches and exposes the scattering covariance C11 directly.
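
The import order described in these notes could look roughly like the sketch below. It is an illustration of the fallback pattern only. The function name select_backend and the warning text are hypothetical, and the module's actual logic may differ.

```python
import importlib.util
import warnings

def select_backend():
    """Return 'scattering', 'kymatio', or None depending on what is importable."""
    # Prefer the Cheng et al. package when it is on the Python path.
    if importlib.util.find_spec("scattering") is not None:
        return "scattering"
    # Otherwise fall back to kymatio and tell the user what is lost.
    if importlib.util.find_spec("kymatio") is not None:
        warnings.warn(
            "scattering (Cheng et al.) not found; falling back to kymatio. "
            "The scattering covariance C11 will be unavailable.",
            RuntimeWarning,
            stacklevel=2,
        )
        return "kymatio"
    return None

backend = select_backend()
print(f"scattering backend: {backend}")
```

Checking with importlib.util.find_spec avoids importing a heavy package just to test for its presence.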

foregrounds_diffusion.scattering_stats.compute_scattering_coefficients(patches_nhw, J=5, L=4, device=None)[source]

Compute first- and second-order scattering coefficients.

Uses the Cheng et al. scattering package if available, otherwise falls back to kymatio.

Parameters:
  • patches_nhw (ndarray, shape (N, H, W)) – Stack of flat-sky patches. All patches must have the same spatial dimensions H × W = 256 × 256.

  • J (int) – Number of dyadic scales (default 5, giving scales 2¹…2⁵ pixels).

  • L (int) – Number of orientations per scale (default 4).

  • device (str or None) – 'cuda', 'cpu', or None (auto-detect).

Returns:

dict with keys:
  • 'S0' (ndarray, shape (N, 1)) – Zeroth-order coefficient (mean pixel value).

  • 'S1' (ndarray, shape (N, J)) – First-order scattering coefficients, orientation-averaged (S1_iso in Cheng et al. notation).

  • 'S2' (ndarray, shape (N, J, J, L)) – Second-order scattering coefficients: cross-scale coupling at pairs (j1, j2) as a function of orientation difference l (S2_iso in Cheng et al. notation).

  • 'S1_mean' (ndarray, shape (J,)) – Mean S1 across all N patches.

  • 'S2_mean' (ndarray, shape (J, J, L)) – Mean S2 across all N patches.

  • 'J', 'L' (int) – Parameters used.
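
Before calling compute_scattering_coefficients it can help to check the input stack and resolve the device explicitly. The following is a minimal sketch under stated assumptions: prepare_patches and resolve_device are hypothetical helpers, not functions of this module, and the auto-detect rule shown (use 'cuda' when PyTorch reports it as available, otherwise 'cpu') is a guess at what device=None does.

```python
import numpy as np

def prepare_patches(patches_nhw):
    """Validate a stack of patches and return it as a contiguous float32 array."""
    patches = np.ascontiguousarray(patches_nhw, dtype=np.float32)
    if patches.ndim != 3:
        raise ValueError(f"expected shape (N, H, W), got {patches.shape}")
    if not np.all(np.isfinite(patches)):
        raise ValueError("patches contain NaN or infinite values")
    return patches

def resolve_device(device=None):
    """Return an explicit device string; None means 'cuda' if usable, else 'cpu'."""
    if device is not None:
        return device
    try:
        import torch  # optional dependency; absent installs fall back to CPU
    except ImportError:
        return "cpu"
    return "cuda" if torch.cuda.is_available() else "cpu"

patches = prepare_patches(np.zeros((2, 256, 256)))
device = resolve_device()
print(patches.shape, patches.dtype, device)
```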

foregrounds_diffusion.scattering_stats.compute_scattering_covariance(patches_nhw, J=5, L=4, device=None)[source]

Compute the scattering covariance C11 (modulus × modulus correlations).

Only available with the Cheng et al. backend. The scattering covariance captures cross-scale correlations that the scattering mean misses, making it sensitive to non-Gaussian structure at multiple scales.

Parameters:
  • patches_nhw (ndarray, shape (N, H, W)) – Stack of flat-sky patches.

  • J (int) – Number of dyadic scales.

  • L (int) – Number of orientations.

  • device (str or None) – 'cuda', 'cpu', or None (auto-detect).

Returns:

dict – Full scattering covariance dictionary from Cheng et al., with keys 'C11_iso', 'C01_iso', 'S1', etc. Returns None if the Cheng et al. backend is not available.

foregrounds_diffusion.scattering_stats.scattering_summary(coeffs, scale_idx=None)[source]

Extract a flat summary vector from scattering coefficients.

Useful for computing residuals or distances between ensembles.

Parameters:
  • coeffs (dict) – Output of compute_scattering_coefficients().

  • scale_idx (list of int, optional) – Subset of j indices (scales) to include. All scales used if None.

Returns:

ndarray, shape (N, n_features) – Flattened scattering features per map.