Source code for foregrounds_diffusion.preprocessing

"""Data preprocessing utilities: normalisation, patch extraction, and filtering.

This module covers everything between raw HEALPix FITS files and the
training-ready ``.npy`` arrays consumed by :mod:`foregrounds_diffusion.train`.
The full pipeline is documented in ``docs/tutorials/02_masking.ipynb`` and
``docs/tutorials/03_patch_extraction.ipynb``.

Normalisation
-------------
Two schemes are used, one per channel:

- **CIB** (``apply_maxmin_normalization``): min–max scaled to [0, 1].
  CIB intensities are non-negative by construction, so this is lossless.
- **tSZ** (``apply_stdnorm``): z-score normalised to zero mean and unit
  variance per channel across the full training set.

The inverse operations are :func:`renormalize_dm_maps` (for post-sampling
rescaling back to physical units) and :func:`denormalize_dm_maps` (for
z-score inversion).

Patch extraction
----------------
:class:`FlatCutter` projects HEALPix full-sky maps onto flat-sky patches
using a gnomonic (tangent-plane) projection.  :func:`get_patch_centers`
returns a grid of pointing centres that tiles the sky.

Filtering
---------
:func:`get_lpf_hpf` builds low-pass, high-pass, or band-pass 2D filters in
ℓ-space.  The training data uses a low-pass cut at ℓ = 7000 to avoid
aliasing artefacts from the HEALPix pixel window function.

:func:`wiener_filter` builds a Wiener filter given signal and noise spectra.

Dataset splitting
-----------------
:func:`split_data_to_tensors` performs a seeded, shuffled train/val/test split
(70/15/15 by default) and transposes from channels-last ``(N, H, W, C)`` to
channels-first ``(N, C, H, W)``. It returns ``torch.Tensor`` objects ready for
the ``DataLoader``.
"""

import astropy.units as u
import healpy as hp
import numpy as np
import torch

from foregrounds_diffusion.flatmaps import cl_to_cl2d, get_lxly

# ---------------------------------------------------------------------------
# Normalisation utilities
# ---------------------------------------------------------------------------


def apply_maxmin_normalization(maps):
    """Rescale an array linearly onto [0, 1] using its global extrema.

    Parameters
    ----------
    maps : ndarray
        Array of any shape. NaNs are skipped when finding the extrema
        (``np.nanmin`` / ``np.nanmax``) and stay NaN in the output.

    Returns
    -------
    ndarray
        New array ``(maps - min) / (max - min)``. If every finite value is
        identical (zero range), an array of zeros shaped like ``maps`` is
        returned instead.
    """
    lo, hi = np.nanmin(maps), np.nanmax(maps)
    span = hi - lo
    if span != 0:
        return (maps - lo) / span
    # Constant input: avoid 0/0 and report "no variation" as all zeros.
    return np.zeros_like(maps)
def apply_stdnorm(maps):
    """Z-score normalise each channel of a channels-last array independently.

    Parameters
    ----------
    maps : array_like, shape (..., C)
        Array with channels in the last axis. Integer input is promoted to
        ``float64``; floating-point input keeps its dtype.

    Returns
    -------
    ndarray, shape (..., C)
        New array with each channel shifted to zero mean and unit variance.
        Channels with zero standard deviation are set to zero. A channel that
        contains NaN has a NaN standard deviation and therefore becomes all
        NaN. The input is not modified.
    """
    maps = np.asarray(maps)
    # Work on a floating-point copy: writing z-scores back into an integer
    # array would silently truncate them (e.g. 1.22 -> 1).
    if np.issubdtype(maps.dtype, np.inexact):
        out = maps.copy()
    else:
        out = maps.astype(np.float64)
    for c in range(out.shape[-1]):
        channel = out[..., c]
        std = np.std(channel)
        if std == 0:
            out[..., c] = 0.0
        else:
            out[..., c] = (channel - np.mean(channel)) / std
    return out
def renormalize_dm_maps(dm_maps, train_maps, variance_scaling=True):
    """Map diffusion-model samples back onto the training-set value range.

    For every channel the samples are first stretched affinely from
    [0, 1] onto ``[min, max]`` of the matching training channel. With
    ``variance_scaling`` they are then re-centred and re-scaled to the
    training channel's mean and standard deviation. A channel that is
    constant after the first step is only shifted to the training mean,
    because dividing by its zero standard deviation would yield NaNs.

    Parameters
    ----------
    dm_maps : ndarray, shape (N, C, H, W)
        Raw diffusion-model samples (channels-first).
    train_maps : ndarray, shape (M, H, W, C)
        Reference training maps (channels-last).
    variance_scaling : bool
        Whether to also match the per-channel mean and standard deviation.

    Returns
    -------
    ndarray, shape (N, C, H, W)
        Renormalised maps in channels-first layout. ``dm_maps`` is not
        modified.
    """
    out = np.transpose(dm_maps, (0, 2, 3, 1)).copy()
    for ch in range(train_maps.shape[-1]):
        reference = train_maps[:, :, :, ch]
        lo = np.min(reference)
        hi = np.max(reference)
        out[:, :, :, ch] = out[:, :, :, ch] * (hi - lo) + lo
        if not variance_scaling:
            continue

        sample = out[:, :, :, ch]
        sample_mean, sample_std = np.mean(sample), np.std(sample)
        ref_mean, ref_std = np.mean(reference), np.std(reference)
        centred = sample - sample_mean
        if sample_std == 0:
            out[:, :, :, ch] = centred + ref_mean
        else:
            out[:, :, :, ch] = centred * (ref_std / sample_std) + ref_mean
    return np.transpose(out, (0, 3, 1, 2))
def rescale_samples(samples, cib_factor=1.0, tsz_factor=1.0):
    """Multiply each DDPM sample channel by a global scalar (paper §3.2).

    Prabhu et al. correct the DDPM's slight under-dispersion by multiplying
    each generated channel by one factor: the ratio of the AGORA sample
    standard deviation to the generated sample standard deviation. Unlike the
    affine :func:`renormalize_dm_maps`, this is a pure scalar multiply (see
    inconsistency #4 in ``docs/paper_code_inconsistencies.md``).

    Parameters
    ----------
    samples : ndarray, shape (N, C, H, W)
        Raw DDPM samples, channels-first (channel 0 = CIB, channel 1 = tSZ).
    cib_factor, tsz_factor : float
        Multipliers for channel 0 (CIB) and channel 1 (tSZ). ``1.0`` leaves a
        channel unchanged. The tSZ factor is ignored when the array has only
        one channel; channels beyond index 1 are never touched. Prabhu et al.
        cite 1.0328 (CIB) and 1.1425 (tSZ) for their checkpoint; for another
        checkpoint, measure ``std(AGORA) / std(generated)`` per channel.

    Returns
    -------
    ndarray, shape (N, C, H, W)
        Rescaled copy; ``samples`` itself is left untouched.
    """
    rescaled = samples.copy()
    has_tsz = rescaled.shape[1] > 1
    rescaled[:, 0] *= cib_factor  # CIB channel is always present
    if has_tsz:
        rescaled[:, 1] *= tsz_factor
    return rescaled
def denormalize_dm_maps(dm_maps, cib_mean, cib_std, tsz_mean, tsz_std):
    """Undo a per-channel Z-score: ``x * std + mean`` for CIB and tSZ.

    Parameters
    ----------
    dm_maps : ndarray, shape (N, 2, H, W)
        DDPM samples in Z-score space, channels-first.
    cib_mean, cib_std : float
        Mean and standard deviation applied to channel 0 (CIB).
    tsz_mean, tsz_std : float
        Mean and standard deviation applied to channel 1 (tSZ).

    Returns
    -------
    ndarray, shape (N, 2, H, W)
        De-standardised copy; ``dm_maps`` is not modified.
    """
    restored = dm_maps.copy()
    channel_stats = ((cib_mean, cib_std), (tsz_mean, tsz_std))
    for channel, (mu, sigma) in enumerate(channel_stats):
        restored[:, channel] = restored[:, channel] * sigma + mu
    return restored
# --------------------------------------------------------------------------- # Moments loading # ---------------------------------------------------------------------------
def load_all_moments(filename, bandpass_centers, max_lines=-1):
    """Load and normalise scattering moments from a .npy file.

    Parameters
    ----------
    filename : str
        Path to the .npy moments array with shape (N, L, 12).
    bandpass_centers : array_like
        Bandpass centre values used for normalisation; must broadcast
        against one row of length ``L``. Lists and tuples are accepted.
    max_lines : int
        Number of realisations to load. *-1* loads all; any other value is
        used as a slice stop (``data[:max_lines]``).

    Returns
    -------
    dict
        Dictionary keyed ``"moment_00"`` … ``"moment_11"``. Each value is a
        list with one normalised array per realisation. Moments 0-6
        (S2aa, S2bb, S2ab, S3aaa, S3bbb, S3aab, S3abb) are divided by
        ``bandpass_centers``. Moments 7-11 (S4aaaa, S4bbbb, S4aaab, S4aabb,
        S4abbb) are divided by ``bandpass_centers**2``.
    """
    raw = np.load(filename)
    moments_data = raw if max_lines == -1 else raw[:max_lines]

    # Coerce so that ``**2`` below also works for plain lists/tuples, which
    # would otherwise raise TypeError.
    centers = np.asarray(bandpass_centers)
    # Order matches the moment index: 7 S2/S3 moments, then 5 S4 moments.
    norms = [centers] * 7 + [centers**2] * 5

    return {
        f"moment_{i:02d}": [m / norms[i] for m in moments_data[:, :, i]]
        for i in range(12)
    }
# ---------------------------------------------------------------------------
# Patch-centre computation and HEALPix patch extraction
# ---------------------------------------------------------------------------


@u.quantity_input
def get_patch_centers(gal_cut: u.deg, step_size: u.deg, pole_cut: u.deg):
    """Compute patch centres on the sky, avoiding the Galactic plane.

    Latitude rows are spaced by ``step_size``. Along each row, longitudes are
    spaced by ``step_size / cos(lat)``, starting at 0 deg and stopping before
    360 deg.

    Parameters
    ----------
    gal_cut : `~astropy.units.Quantity`
        Half-width of the Galactic-plane exclusion zone in degrees.
        Southern rows lie strictly below ``-(gal_cut + step_size)``.
        Northern rows start at ``gal_cut + step_size``.
    step_size : `~astropy.units.Quantity`
        Stepping distance in Galactic latitude in degrees.
    pole_cut : `~astropy.units.Quantity`
        Angular width excluded around each pole. Southern rows start at
        ``-90 deg + pole_cut / 2``. Northern rows stop before
        ``90 deg - pole_cut / 2``.

    Returns
    -------
    list of tuple
        Each element is ``(lon, lat)`` as `~astropy.units.Quantity` in degrees.
    """
    gal_cut = gal_cut.to(u.deg)
    step_size = step_size.to(u.deg)
    pole_cut = pole_cut.to(u.deg)
    southern = (
        np.arange(-90 + pole_cut.value / 2, (-gal_cut - step_size).value, step_size.value)
        * u.deg
    )
    northern = (
        np.arange((gal_cut + step_size).value, 90 - pole_cut.value / 2, step_size.value)
        * u.deg
    )
    lat_range = np.concatenate((southern, northern))
    centers = []
    for lat in lat_range:
        # Widen the longitude step away from the equator so neighbouring
        # centres stay roughly ``step_size`` apart on the sphere.
        step = step_size.value / np.cos(lat.to(u.rad).value)
        for lon in np.arange(0, 360, step):
            centers.append((lon * u.deg, lat))
    return centers


class FlatCutter:
    """Extract flat-sky patches from a HEALPix map by rotation and interpolation.

    The pixel grid is built once as unit-height vectors ``(x, y, 1)``, i.e. a
    tangent plane at the north pole. Each call to
    :meth:`rotate_to_pole_and_interpolate` rotates that grid to the requested
    centre and samples the map there.

    Parameters
    ----------
    ang_x, ang_y : `~astropy.units.Quantity`
        Angular extent of the patch in the x and y directions.
    xres, yres : int
        Number of pixels in x and y.
    """

    @u.quantity_input
    def __init__(self, ang_x: u.deg, ang_y: u.deg, xres: int, yres: int):
        self.xres = xres
        self.yres = yres
        self.ang_x = ang_x
        self.ang_y = ang_y
        self.xarr = np.linspace(
            -self.ang_x.to(u.rad).value / 2.0, self.ang_x.to(u.rad).value / 2.0, xres
        )
        self.yarr = np.linspace(
            -self.ang_y.to(u.rad).value / 2.0, self.ang_y.to(u.rad).value / 2.0, yres
        )
        # Default "xy" indexing: both grids have shape (yres, xres), so the
        # raveled pixel order is row-major over (y, x).
        xgrid, ygrid = np.meshgrid(self.xarr, self.yarr)
        xgrid = xgrid.ravel()[None, :]
        ygrid = ygrid.ravel()[None, :]
        zgrid = np.ones_like(ygrid)
        self.vecs = np.concatenate((xgrid, ygrid, zgrid)).T
        self.lons, self.lats = hp.vec2ang(self.vecs, lonlat=True)
        self.lats *= u.deg
        self.lons *= u.deg

    @u.quantity_input
    def rotate_to_pole_and_interpolate(self, lon: u.deg, lat: u.deg, ma):
        """Rotate the patch grid to *(lon, lat)* and sample the map.

        Also stores the sampled sky coordinates in ``self.inv_lon_grid`` and
        ``self.inv_lat_grid``.

        Parameters
        ----------
        lon, lat : `~astropy.units.Quantity`
            Sky position of the patch centre.
        ma : ndarray or list of ndarray
            HEALPix map(s) to sample.

        Returns
        -------
        ndarray, shape (yres, xres, nmaps)
            Interpolated flat-sky patch(es). The trailing map axis is always
            present, even for a single input map (``nmaps == 1``). Axis 0
            follows ``yarr`` and axis 1 follows ``xarr``. For square patches
            this is ``(xres, yres, nmaps)``.
        """
        if hp.pixelfunc.maptype(ma) == 0:
            ma = [ma]
        rotator = hp.Rotator(rot=[lon.to(u.deg).value, lat.to(u.deg).value - 90.0], deg=True)
        self.inv_lon_grid, self.inv_lat_grid = rotator.I(
            self.lons.to(u.deg).value, self.lats.to(u.deg).value, lonlat=True
        )
        m_rot = [
            hp.get_interp_val(each, self.inv_lon_grid, self.inv_lat_grid, lonlat=True)
            for each in ma
        ]
        if len(m_rot) > 1:
            # NOTE(review): with more than one map, the LAST TWO maps are
            # treated as a spin-2 (Q, U) pair and rotated. That is right for
            # e.g. [T, Q, U] input, but two scalar maps passed together would
            # be mixed. Confirm callers pass scalar maps one at a time.
            m_rot[-2], m_rot[-1] = _spin2rot(
                m_rot[-2],
                m_rot[-1],
                rotator.angle_ref(self.inv_lon_grid, self.inv_lat_grid, lonlat=True),
            )
            m_rot[-2], m_rot[-1] = _spin2rot(m_rot[-2], m_rot[-1], self.lons.to(u.rad).value)
        else:
            m_rot = m_rot[0]
        # Pixels are ordered row-major over (yres, xres) (see __init__).
        # Reshaping to (xres, yres) would scramble non-square patches.
        return np.moveaxis(np.array(m_rot).reshape(-1, self.yres, self.xres), 0, -1)


def _spin2rot(q, u, angle):
    """Rotate spin-2 field (Q, U) by *angle* (internal helper).

    Returns ``(cos 2a * q - sin 2a * u, sin 2a * q + cos 2a * u)``.
    """
    c, s = np.cos(2 * angle), np.sin(2 * angle)
    return c * q - s * u, s * q + c * u


# ---------------------------------------------------------------------------
# HEALPix map utilities
# ---------------------------------------------------------------------------
def replace_zeros_with_neighbor_avg(hp_map):
    """Fill zero-valued pixels of a HEALPix map from their non-zero neighbours.

    The zero pixels are located once, up front, and then visited in ascending
    pixel-index order. Each is set to the mean of its neighbours that are
    currently non-zero, or to 0 when none are. Because the map is updated
    while it is scanned, a value filled earlier can contribute to the average
    of a later zero pixel.

    Parameters
    ----------
    hp_map : ndarray
        1D HEALPix map array. Neighbours come from ``hp.get_all_neighbours``
        with its default ordering arguments (presumably RING; confirm this
        matches the map).

    Returns
    -------
    ndarray
        The same array object, modified in place. Filled values are stored
        with ``hp_map``'s dtype.
    """
    nside = hp.get_nside(hp_map)
    for pix in np.where(hp_map == 0)[0]:
        ring = hp.get_all_neighbours(nside, pix)
        # Negative entries mark missing neighbours. ``hp_map[ring]`` still
        # evaluates for them, but the ``ring >= 0`` mask discards them.
        usable = ring[(ring >= 0) & (hp_map[ring] != 0)]
        if len(usable) == 0:
            hp_map[pix] = 0
        else:
            hp_map[pix] = np.mean(hp_map[usable])
    return hp_map
# --------------------------------------------------------------------------- # Fourier-space filtering # ---------------------------------------------------------------------------
def get_lpf_hpf(flatskymapparams, lmin_lmax, filter_type=0):
    """Build a 2D binary Fourier filter (low-pass, high-pass, or band-pass).

    Parameters
    ----------
    flatskymapparams : list
        [nx, ny, dx, dy] — see :func:`~foregrounds_diffusion.flatmaps.get_lxly`.
    lmin_lmax : float or tuple of float
        Cutoff multipole (scalar) for low/high-pass, or ``(lmin, lmax)`` for
        band-pass.
    filter_type : int
        0 → low-pass (keeps ``ell <= lmin_lmax``),
        1 → high-pass (keeps ``ell >= lmin_lmax``),
        2 → band-pass (keeps ``lmin <= ell <= lmax``).

    Returns
    -------
    ndarray
        2D array of ones (kept modes) and zeros (removed modes).

    Raises
    ------
    ValueError
        If ``filter_type`` is not 0, 1 or 2. Previously such values silently
        produced an all-pass filter.
    """
    # Validate first so a typo cannot silently disable filtering.
    if filter_type not in (0, 1, 2):
        raise ValueError(
            "filter_type must be 0 (low-pass), 1 (high-pass) or 2 (band-pass); "
            f"got {filter_type!r}"
        )

    lx, ly = get_lxly(flatskymapparams)
    ell = np.sqrt(lx**2.0 + ly**2.0)
    fft_filter = np.ones(ell.shape)
    if filter_type == 0:
        fft_filter[ell > lmin_lmax] = 0.0
    elif filter_type == 1:
        fft_filter[ell < lmin_lmax] = 0.0
    else:
        lmin, lmax = lmin_lmax
        fft_filter[(ell < lmin) | (ell > lmax)] = 0.0
    return fft_filter
def wiener_filter(mapparams, cl_signal, cl_noise, el=None):
    """Compute a 2D Wiener filter ``S / (S + N)`` from 1D power spectra.

    Parameters
    ----------
    mapparams : list
        [nx, ny, dx, dy] — see :func:`~foregrounds_diffusion.flatmaps.get_lxly`.
    cl_signal, cl_noise : array_like
        1D signal and noise power spectra.
    el : array_like, optional
        Multipoles. Defaults to ``np.arange(len(cl_signal))``.

    Returns
    -------
    ndarray
        2D Wiener filter. Modes where the 2D signal and noise spectra sum to
        zero are set to 0 rather than NaN.
    """
    if el is None:
        el = np.arange(len(cl_signal))
    signal2d = np.asarray(cl_to_cl2d(el, cl_signal, mapparams))
    noise2d = np.asarray(cl_to_cl2d(el, cl_noise, mapparams))
    total = signal2d + noise2d
    # S / (S + N) is 0/0 wherever both 2D spectra vanish. The original code
    # returned NaN there, which poisons any map the filter is applied to.
    # Those modes carry no signal, so a weight of 0 is returned instead.
    out = np.zeros(total.shape, dtype=np.result_type(total.dtype, np.float32))
    return np.divide(signal2d, total, out=out, where=total != 0)
# --------------------------------------------------------------------------- # Dataset splitting # ---------------------------------------------------------------------------
def split_data_to_tensors(data, train_size=0.7, val_size=0.15, test_size=0.15, seed=42):
    """Shuffle a channels-last array and split it into train/val/test tensors.

    Parameters
    ----------
    data : ndarray, shape (N, H, W, C)
        Input data in channels-last layout.
    train_size, val_size, test_size : float
        Fractional split sizes; must sum to 1 (checked with ``np.isclose``).
        The train and validation counts are ``int(fraction * N)`` (truncated).
        The test set receives every remaining sample.
    seed : int
        Seed for ``np.random.default_rng``, making the shuffle reproducible.

    Returns
    -------
    train_set, val_set, test_set : torch.Tensor
        ``float32`` tensors in channels-first layout (N_i, C, H, W).

    Raises
    ------
    ValueError
        If the three fractions do not sum to 1.
    """
    if not np.isclose(train_size + val_size + test_size, 1.0):
        raise ValueError("train_size + val_size + test_size must equal 1.")

    n_samples = data.shape[0]
    order = np.arange(n_samples)
    np.random.default_rng(seed).shuffle(order)

    n_train = int(train_size * n_samples)
    n_val = int(val_size * n_samples)
    parts = np.split(order, [n_train, n_train + n_val])

    def _to_channels_first(idx):
        # (n, H, W, C) -> (n, C, H, W), copied into a new float32 tensor.
        return torch.tensor(data[idx].transpose(0, 3, 1, 2), dtype=torch.float32)

    return tuple(_to_channels_first(idx) for idx in parts)
# --------------------------------------------------------------------------- # Data augmentation # ---------------------------------------------------------------------------
def augment_images_unique(images):
    """Expand each image into its 8 dihedral variants (4 rotations x flip).

    For every image the output holds, in order: rot0, flip(rot0), rot90,
    flip(rot90), rot180, flip(rot180), rot270, flip(rot270). Rotations act
    on the spatial axes and the flip reverses the width axis. For an image
    with rotational or mirror symmetry, some of the 8 variants coincide.

    Parameters
    ----------
    images : torch.Tensor, shape (N, C, H, W)
        Training images in channels-first layout.

    Returns
    -------
    torch.Tensor, shape (8N, C, H, W)
        Augmented images, grouped per original image.
    """

    def _orbit(image):
        for quarter_turns in range(4):
            turned = torch.rot90(image, k=quarter_turns, dims=(1, 2))
            yield turned
            yield torch.flip(turned, dims=[2])

    return torch.stack([view for image in images for view in _orbit(image)])