{ "cells": [ { "cell_type": "markdown", "id": "9a9d4c02", "metadata": {}, "source": [ "# 11 — Scattering Transforms\n", "\n", "**Purpose:** Compute scattering transform coefficients for CIB and tSZ patches\n", "and compare Agora, DDPM, and Gaussian baseline distributions.\n", "\n", "The scattering transform is a non-linear multi-scale statistical descriptor that\n", "captures information complementary to the power spectrum. It applies a sequence\n", "of wavelet convolutions and modulus operations, producing coefficients that are\n", "sensitive to non-Gaussian structure at multiple angular scales and orientations.\n", "\n", "This notebook computes:\n", "- **S1** — first-order coefficients: mean wavelet modulus at each scale j and\n", " orientation l. These are related to the power spectrum.\n", "- **S2** — second-order coefficients: mean modulus of modulus at pairs (j1,l1),\n", " (j2,l2) with j2 > j1. These capture non-Gaussian cross-scale correlations.\n", "- **C11** (optional) — scattering covariance: full cross-scale correlation\n", " matrix. Requires the Cheng et al. backend.\n", "\n", "## Installation\n", "\n", "**Option A — Cheng et al. 
(recommended, faster, exposes C11):**\n", "```bash\n", "cd ~/cmb_foregrounds_diffusion\n", "git clone https://github.com/SihaoCheng/scattering_transform.git\n", "```\n", "Then add the clone to `sys.path` in this notebook (see the first code cell).\n", "\n", "**Option B — kymatio (pip-installable, slower, no C11):**\n", "```bash\n", "pip install kymatio\n", "```\n", "\n", "**Inputs:**\n", "- Test maps: `data/low_pass/2mJy/*.npy`\n", "- DDPM samples: `data/low_pass/2mJy/new_samples_*.npy`\n", "- Norm params: `data/low_pass/2mJy/norm_params_2mJy.npy`\n", "\n", "**Outputs:**\n", "- `plots/scattering_S1.pdf`\n", "- `plots/scattering_S2_ratio.pdf`\n", "- `plots/scattering_residuals.pdf`\n", "- `plots/scattering_C11.pdf` (Cheng et al. backend only)\n", "\n", "**Key module functions:**\n", "- `foregrounds_diffusion.scattering_stats.compute_scattering_coefficients`\n", "- `foregrounds_diffusion.scattering_stats.compute_scattering_covariance`\n", "- `foregrounds_diffusion.scattering_stats.scattering_summary`" ] }, { "cell_type": "markdown", "id": "95afffae", "metadata": {}, "source": [ "## 1 Backend and configuration\n", "\n", "The scattering transform requires one of two backends:\n", "\n", "**Cheng et al. (recommended)**: clone\n", "`https://github.com/SihaoCheng/scattering_transform` alongside the project\n", "and set `SCATTERING_REPO` to point to it. This backend is significantly\n", "faster than kymatio on CPU and supports the `C11` cross-scale covariance\n", "statistic (§4.3 of Cheng et al. 2020).\n", "\n", "**kymatio (fallback)**: `pip install kymatio`. Slower on CPU but works\n", "without any path configuration.\n", "\n", "The next cell checks for the Cheng et al. backend and adds it to `sys.path`.\n", "This must happen before importing `scattering_stats`, which tries to import\n", "the backend at module level and falls back to kymatio if it is not found.\n", "\n", "Key parameters:\n", "- **J** — number of dyadic scales, j = 1…J. Scale j corresponds to wavelets\n", "  of support ≈ 2ʲ pixels. At 1.41′/px, the J = 5 set below covers\n", "  2.8′ → 45′ (J = 6 would reach 90′).\n", "- **L** — number of orientations per scale, with angular resolution 180°/L.\n", "  The L = 4 set below gives 45°; L = 8 gives 22.5°, sufficient to detect\n", "  preferred orientations in cluster filaments." 
] }, { "cell_type": "code", "execution_count": null, "id": "156e003f", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from pathlib import Path\n", "\n", "# If using Cheng et al. backend, add the cloned repo to the path\n", "SCATTERING_REPO = Path('/home/apb86/cmb_foregrounds_diffusion/scattering_transform')\n", "if SCATTERING_REPO.exists():\n", "    sys.path.insert(0, str(SCATTERING_REPO))\n", "    print('Using Cheng et al. scattering backend')\n", "else:\n", "    print('Cheng et al. repo not found — will try kymatio')\n", "    print(f'To install: git clone https://github.com/SihaoCheng/scattering_transform.git {SCATTERING_REPO}')\n", "\n", "from foregrounds_diffusion.scattering_stats import (\n", "    compute_scattering_coefficients,\n", "    compute_scattering_covariance,\n", "    scattering_summary,\n", ")\n", "\n", "PTSRC = 2\n", "PIXEL_RES_ARCMIN = 1.40625\n", "N_MAPS = 100  # number of maps per source\n", "J = 5  # number of scales: 2^1 … 2^5 pixels\n", "L = 4  # number of orientations\n", "\n", "PATCHES_DIR = Path(f'data/low_pass/{PTSRC}mJy')\n", "PROJECT_ROOT = Path('/home/apb86/cmb_foregrounds_diffusion')\n" ] }, { "cell_type": "markdown", "id": "ac83108e", "metadata": {}, "source": [ "## 2 Load and denormalise maps\n", "\n", "Same loading pattern (train/test split and denormalisation) as notebooks\n", "06–10. Maps are kept in float32 throughout to halve memory and avoid\n", "kymatio's implicit float64 cast. Scattering coefficients are\n", "dimensionless relative quantities, so the absolute amplitude of the\n", "denormalisation is less critical than for power spectra. Using physical µK\n", "units ensures that the ratio DDPM/Agora is directly interpretable as a\n", "fractional amplitude error." 
] }, { "cell_type": "code", "execution_count": null, "id": "eec854f3", "metadata": {}, "outputs": [], "source": [ "# ---------------------------------------------------------------------------\n", "# Load and denormalise maps\n", "# ---------------------------------------------------------------------------\n", "norm_params = np.load(PATCHES_DIR / f'norm_params_{PTSRC}mJy.npy')\n", "cib_mean, cib_std, tsz_mean, tsz_std = norm_params\n", "\n", "cib_maps = np.load(PATCHES_DIR / f'CIB_map_150GHz_256_st6_zscore_{PTSRC}mJy_lp.npy')\n", "tsz_maps = np.load(PATCHES_DIR / f'tSZ3_map_150GHz_256_st6_zscore_{PTSRC}mJy_lp.npy')\n", "\n", "rng = np.random.default_rng(seed=42)\n", "indices = rng.permutation(len(cib_maps))\n", "test_idx = indices[int(0.8 * len(cib_maps)):]\n", "\n", "cib_test = (cib_maps[test_idx, :, :, 0] * cib_std + cib_mean).astype(np.float32)\n", "tsz_test = (tsz_maps[test_idx, :, :, 0] * tsz_std + tsz_mean).astype(np.float32)\n", "print(f'Test patches: {len(cib_test)}')\n", "\n", "ddpm_raw = np.load(\n", "    PROJECT_ROOT / 'data' / 'low_pass' / f'{PTSRC}mJy' /\n", "    f'new_samples_cib_tsz_{PTSRC}mJy_zero_norm_6x6_w_au_lp.npy'\n", ")\n", "ddpm_cib = (ddpm_raw[:, 0] * cib_std + cib_mean).astype(np.float32)\n", "ddpm_tsz = (ddpm_raw[:, 1] * tsz_std + tsz_mean).astype(np.float32)\n", "\n", "gauss_maps = np.load(PATCHES_DIR / f'gaussian_cib_tsz_{PTSRC}mJy_lp.npy')\n", "gauss_cib = (gauss_maps[:, 0] if gauss_maps.shape[1] == 2\n", "             else gauss_maps[:, :, :, 0]).astype(np.float32)\n", "gauss_tsz = (gauss_maps[:, 1] if gauss_maps.shape[1] == 2\n", "             else gauss_maps[:, :, :, 1]).astype(np.float32)\n", "\n", "N = min(N_MAPS, len(cib_test), len(ddpm_cib), len(gauss_cib))\n", "print(f'Using {N} maps per source')\n" ] },
{ "cell_type": "markdown", "id": "74681c7e", "metadata": {}, "source": [ "## 3 Compute first- and second-order scattering coefficients\n", "\n", "`compute_scattering_coefficients(maps, J, L)` takes:\n", "- `maps`: `(N, H, W)` float32 array in µK\n", "- `J`: number of dyadic scales (wavelet support ≈ 2^j pixels per scale)\n", "- `L`: number of orientations (angular resolution = 180° / L)\n", "\n", "It returns a dict:\n", "```python\n", "{\n", "    'S0': ndarray(N,),         # global mean (not usually used)\n", "    'S1': ndarray(N, J, L),    # first-order: ⟨|W_j,l * f|⟩ (mean wavelet modulus)\n", "    'S2': ndarray(N, J, J, L), # second-order: ⟨|W_j2,l2 * |W_j1,l1 * f||⟩\n", "}\n", "```\n", "S1 captures the energy at each scale-orientation pair — closely related to\n", "the power spectrum but orientation-resolved. S2 captures cross-scale coupling\n", "(how structures at scale j1 modulate structures at scale j2 > j1) and is\n", "sensitive to non-Gaussian filamentary structure invisible to the power spectrum.\n", "\n", "Runtime: ≈ 2–5 min per source on a single CPU core for 100 maps at 256×256." ] }, { "cell_type": "code", "execution_count": null, "id": "7a59253c", "metadata": {}, "outputs": [], "source": [ "# ---------------------------------------------------------------------------\n", "# Compute scattering coefficients\n", "# NOTE: ~2-5 min per source on CPU; much faster on GPU\n", "# ---------------------------------------------------------------------------\n", "print('Computing Agora CIB scattering coefficients...')\n", "s_agora_cib = compute_scattering_coefficients(cib_test[:N], J=J, L=L)\n", "\n", "print('Computing Agora tSZ scattering coefficients...')\n", "s_agora_tsz = compute_scattering_coefficients(tsz_test[:N], J=J, L=L)\n", "\n", "print('Computing DDPM CIB scattering coefficients...')\n", "s_ddpm_cib = compute_scattering_coefficients(ddpm_cib[:N], J=J, L=L)\n", "\n", "print('Computing DDPM tSZ scattering coefficients...')\n", "s_ddpm_tsz = compute_scattering_coefficients(ddpm_tsz[:N], J=J, L=L)\n", "\n", "print('Computing Gaussian CIB scattering coefficients...')\n", "s_gauss_cib = compute_scattering_coefficients(gauss_cib[:N], J=J, 
L=L)\n", "\n", "print('Computing Gaussian tSZ scattering coefficients...')\n", "s_gauss_tsz = compute_scattering_coefficients(gauss_tsz[:N], J=J, L=L)\n", "\n", "print('Done.')\n" ] }, { "cell_type": "markdown", "id": "c2321ebe", "metadata": {}, "source": [ "## 4 First-order coefficients — power vs scale\n", "\n", "S1 averaged over orientations gives a single value per scale j: the mean\n", "modulus of the wavelet transform. This is closely related to (but not\n", "identical to) the bandpass-filtered variance used in notebook 07. Plot S1_j\n", "vs scale 2ʲ (in pixels) for Agora, DDPM, and Gaussian, separately for CIB\n", "and tSZ. The Gaussian baseline should roughly match Agora at this order if\n", "the power spectra agree (S1 is a mean modulus rather than a variance, so\n", "exact agreement is not expected); large departures indicate normalisation\n", "errors." ] }, { "cell_type": "code", "execution_count": null, "id": "6ede3110", "metadata": {}, "outputs": [], "source": [ "# ---------------------------------------------------------------------------\n", "# Plot S1 coefficients averaged over orientations\n", "# S1[j] = mean wavelet modulus at scale 2^j — related to power spectrum\n", "# S1 is either (N, J, L) or already orientation-averaged (N, J), e.g. S1_iso\n", "# from the Cheng et al. backend; both cases are handled below.\n", "# ---------------------------------------------------------------------------\n", "scales = np.arange(1, J + 1)  # j = 1 … J\n", "scale_labels = [f'$2^{j}$ px' for j in scales]\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 5))\n", "\n", "sources_s1 = [\n", "    ('Agora', 'C0', '-', s_agora_cib, s_agora_tsz),\n", "    ('DDPM', 'C1', '-', s_ddpm_cib, s_ddpm_tsz),\n", "    ('Gaussian', 'C2', '--', s_gauss_cib, s_gauss_tsz),\n", "]\n", "\n", "for ax, title in zip(axes, ['CIB', 'tSZ']):\n", "    for label, color, ls, s_cib, s_tsz in sources_s1:\n", "        s = s_cib if title == 'CIB' else s_tsz\n", "        s1 = s['S1']\n", "        # Average over the orientation axis if the backend kept it\n", "        if s1.ndim == 3:\n", "            s1 = s1.mean(axis=-1)  # (N, J, L) -> (N, J)\n", "        mean = s1.mean(axis=0)  # (J,)\n", "        std = s1.std(axis=0)\n", "        ax.plot(scales, mean, color=color, ls=ls, lw=1.5,
                label=label, marker='o')\n", "        ax.fill_between(scales, mean - std, mean + std, alpha=0.2, color=color)\n", "    ax.set_xticks(scales)\n", "    ax.set_xticklabels(scale_labels)\n", "    ax.set_xlabel('Wavelet scale j')\n", "    ax.set_ylabel(r'$S_1(j)$ (orientation-averaged)')\n", "    ax.set_title(f'{title} — First-order scattering coefficients S1')\n", "    ax.legend()\n", "    ax.set_yscale('log')\n", "\n", "plt.tight_layout()\n", "Path('plots').mkdir(exist_ok=True)\n", "plt.savefig('plots/scattering_S1.pdf', dpi=200, bbox_inches='tight')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "016a6864", "metadata": {}, "source": [ "## 5 Second-order coefficients — cross-scale coupling matrix\n", "\n", "S2_{j1,j2} (averaged over orientations) forms a (J × J) upper-triangular\n", "matrix (only j2 > j1 is defined). The ratio DDPM/Agora reveals whether the\n", "DDPM has correctly learnt the inter-scale modulation structure. Values close\n", "to 1 (white on the diverging `RdBu_r` colour map of the ratio plot) indicate\n", "good agreement; values < 0.5 or > 2 indicate systematic under- or\n", "over-production of cross-scale non-Gaussianity. Note that the colour scale\n", "saturates at 0.5 and 1.5." 
] }, { "cell_type": "code", "execution_count": null, "id": "5d385d4b", "metadata": {}, "outputs": [], "source": [ "# ---------------------------------------------------------------------------\n", "# Plot S2 coefficients — cross-scale coupling matrix\n", "# S2 shape: (N, J, J, L) — axes are (map, j1, j2, orientation_difference)\n", "# Average over N and orientation difference to get a (J, J) matrix\n", "# ---------------------------------------------------------------------------\n", "fig, axes = plt.subplots(1, 2, figsize=(14, 6))\n", "\n", "for ax, title, s_agora, s_ddpm in zip(\n", "    axes,\n", "    ['CIB', 'tSZ'],\n", "    [s_agora_cib, s_agora_tsz],\n", "    [s_ddpm_cib, s_ddpm_tsz],\n", "):\n", "    # S2 shape: (N, J, J, L) → average over N and L → (J, J)\n", "    s2_agora = s_agora['S2'].mean(axis=(0, 3))  # (J, J)\n", "    s2_ddpm = s_ddpm['S2'].mean(axis=(0, 3))\n", "\n", "    # Ratio DDPM/Agora (only upper triangle where j2>j1 is defined)\n", "    with np.errstate(divide='ignore', invalid='ignore'):\n", "        ratio = np.where(s2_agora > 0, s2_ddpm / s2_agora, np.nan)\n", "\n", "    # Rows (y axis) index j1, columns (x axis) index j2\n", "    im = ax.imshow(ratio, origin='lower', cmap='RdBu_r',\n", "                   vmin=0.5, vmax=1.5, aspect='auto')\n", "    ax.set_xticks(range(J)); ax.set_yticks(range(J))\n", "    ax.set_xticklabels([f'j={j+1}' for j in range(J)])\n", "    ax.set_yticklabels([f'j={j+1}' for j in range(J)])\n", "    ax.set_xlabel('Scale j2'); ax.set_ylabel('Scale j1')\n", "    ax.set_title(f'{title} — S2 ratio DDPM/Agora\\n(1=perfect agreement)')\n", "    plt.colorbar(im, ax=ax)\n", "\n", "plt.tight_layout()\n", "plt.savefig('plots/scattering_S2_ratio.pdf', dpi=200, bbox_inches='tight')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "bae42390", "metadata": {}, "source": [ "## 6 Summary feature vector and standardised residuals\n", "\n", "`scattering_summary` concatenates the orientation-averaged S1 and S2\n", "coefficients into a single 1D feature vector of length J + J(J−1)/2 per map.\n", "The standardised residual (Agora − DDPM) / σ_Agora per feature 
component\n", "provides a bar chart diagnostic: components with |residual| > 2 identify\n", "specific scale combinations where the DDPM statistics deviate from Agora\n", "beyond the Agora sample variance." ] }, { "cell_type": "code", "execution_count": null, "id": "4b0afa68", "metadata": {}, "outputs": [], "source": [ "# ---------------------------------------------------------------------------\n", "# Summary: flattened scattering feature vector residuals\n", "# ---------------------------------------------------------------------------\n", "feat_agora_cib = scattering_summary(s_agora_cib) # (N, n_features)\n", "feat_ddpm_cib = scattering_summary(s_ddpm_cib)\n", "feat_agora_tsz = scattering_summary(s_agora_tsz)\n", "feat_ddpm_tsz = scattering_summary(s_ddpm_tsz)\n", "\n", "resid_cib = ((feat_agora_cib.mean(0) - feat_ddpm_cib.mean(0))\n", " / (feat_agora_cib.std(0) + 1e-10))\n", "resid_tsz = ((feat_agora_tsz.mean(0) - feat_ddpm_tsz.mean(0))\n", " / (feat_agora_tsz.std(0) + 1e-10))\n", "\n", "fig, axes = plt.subplots(1, 2, figsize=(12, 4))\n", "for ax, resid, title in zip(axes, [resid_cib, resid_tsz], ['CIB', 'tSZ']):\n", " ax.bar(range(len(resid)), resid, color='C0', alpha=0.7)\n", " ax.axhline(0, color='k', lw=0.8, ls='--')\n", " ax.axhline(1, color='gray', lw=0.6, ls=':')\n", " ax.axhline(-1, color='gray', lw=0.6, ls=':')\n", " ax.set_xlabel('Scattering feature index')\n", " ax.set_ylabel(r'$(\\bar{S}^{\\rm Agora} - \\bar{S}^{\\rm DDPM}) / \\sigma^{\\rm Agora}$')\n", " ax.set_title(f'{title} scattering coefficient residuals')\n", " ax.set_ylim(-3, 3)\n", "\n", "plt.tight_layout()\n", "plt.savefig('plots/scattering_residuals.pdf', dpi=200, bbox_inches='tight')\n", "plt.show()\n", "\n", "print(f'CIB: {(np.abs(resid_cib) < 1).mean():.1%} of features within 1σ')\n", "print(f'tSZ: {(np.abs(resid_tsz) < 1).mean():.1%} of features within 1σ')\n" ] }, { "cell_type": "markdown", "id": "0157bd4b", "metadata": {}, "source": [ "## 7 Optional: scattering covariance C11 (Cheng 
et al. backend)\n", "\n", "The `C11` (cross-wavelet covariance) statistic measures how the wavelet\n", "modulus at scale j1 covaries with the wavelet coefficients at scale j2 across\n", "orientations. It is defined in Cheng et al. (2020) and requires the Cheng\n", "backend; kymatio does not support it. `C11_iso` is its isotropic projection,\n", "indexed in the code below by three scales and two orientation differences.\n", "Averaging |`C11_iso`| over maps and both orientation axes gives a `(J, J, J)`\n", "tensor; a further average over j1 gives the (J, J) heatmap, which shows which\n", "scale pairs have the strongest cross-covariance." ] }, { "cell_type": "code", "execution_count": null, "id": "65e76d9a", "metadata": {}, "outputs": [], "source": [ "# ---------------------------------------------------------------------------\n", "# Optional: scattering covariance C11 (Cheng et al. backend only)\n", "# ---------------------------------------------------------------------------\n", "print('Computing scattering covariance C11 for Agora CIB...')\n", "cov_agora_cib = compute_scattering_covariance(cib_test[:N], J=J, L=L)\n", "\n", "if cov_agora_cib is not None:\n", "    print(f'Available keys: {list(cov_agora_cib.keys())}')\n", "    c11_iso = cov_agora_cib['C11_iso']  # shape (N, J, J, J, L, L)\n", "    print(f'C11_iso shape: {c11_iso.shape}')\n", "\n", "    # Mean |C11_iso| over maps and both orientation axes\n", "    c11_mean = np.nanmean(np.abs(c11_iso), axis=(0, 4, 5))  # (J, J, J)\n", "\n", "    # Average over j1 (axis 0, assuming axis order j1, j2, j3), skipping\n", "    # undefined (NaN) entries; rows then index j2 and columns index j3\n", "    fig, ax = plt.subplots(figsize=(6, 5))\n", "    im = ax.imshow(np.nanmean(c11_mean, axis=0), origin='lower', cmap='viridis')\n", "    ax.set_xlabel('Scale j3'); ax.set_ylabel('Scale j2')\n", "    ax.set_title('Agora CIB — mean |C11_iso| (j1-averaged)')\n", "    plt.colorbar(im, ax=ax)\n", "    plt.tight_layout()\n", "    plt.savefig('plots/scattering_C11.pdf', dpi=200, bbox_inches='tight')\n", "    plt.show()\n", "else:\n", "    print('Scattering covariance not available (requires Cheng et al. 
backend).')\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 5 }