Source code for foregrounds_diffusion.stacking
import numpy as np
# ---------------------------------------------------------------------------
# tSZ cluster stacking utilities
# ---------------------------------------------------------------------------
[docs]
def select_snr_pixels(tsz_maps_nhw, snr_min, snr_max, min_separation=5):
    """Find local SNR-peak coordinates within a given SNR bin.

    Each patch is turned into an SNR map by dividing it by its own standard
    deviation, computed over the whole patch (signal included). Patches
    whose standard deviation is exactly zero are skipped. A pixel is a peak
    when it equals the maximum of the ``min_separation``-sized window around
    it and its SNR lies in ``[snr_min, snr_max)``.

    When the window is at least 3 pixels wide, neighbouring pixels can only
    both pass the local-maximum test if they share the same value (a
    flat-topped maximum). Such a group is reported once, at its first pixel
    in row-major order, rather than once per tied pixel.

    Parameters
    ----------
    tsz_maps_nhw : ndarray, shape (N, H, W)
        Stack of N two-dimensional maps.
    snr_min : float
        Lower SNR bound (inclusive).
    snr_max : float or None
        Upper SNR bound (exclusive). *None* means no upper bound.
    min_separation : int
        Side length of the square window used by the local-maximum test
        (passed as ``size`` to :func:`scipy.ndimage.maximum_filter`).
        Larger values suppress nearby secondary peaks.

    Returns
    -------
    list of tuple
        Each element is ``(patch_idx, row, col)``, ordered by patch index
        and then row-major within each patch.
    """
    from scipy.ndimage import label, maximum_filter

    # With a window of size >= 3 every pixel's 8 neighbours lie inside its
    # window, so two adjacent peaks must be equal: connected groups of
    # peaks are plateaus and can safely be collapsed to one pixel.
    merge_plateaus = bool(np.all(np.asarray(min_separation) >= 3))
    eight_connected = np.ones((3, 3), dtype=bool)

    coords = []
    for i, m in enumerate(tsz_maps_nhw):
        noise = m.std()
        if noise == 0:
            continue
        snr_map = m / noise
        local_max = maximum_filter(snr_map, size=min_separation) == snr_map
        in_bin = (snr_map >= snr_min) & local_max
        if snr_max is not None:
            in_bin &= snr_map < snr_max
        if merge_plateaus:
            labels, n_groups = label(in_bin, structure=eight_connected)
            if n_groups == 0:
                continue
            group_ids, first_flat = np.unique(labels, return_index=True)
            # Label 0 is background; keep the first pixel of every group and
            # sort so the output stays in row-major order.
            first_flat = np.sort(first_flat[group_ids != 0])
            rows, cols = np.unravel_index(first_flat, snr_map.shape)
            peaks = zip(rows, cols)
        else:
            peaks = np.argwhere(in_bin)
        for ri, rj in peaks:
            coords.append((i, int(ri), int(rj)))
    print(f"SNR [{snr_min}, {snr_max}): {len(coords)} peaks found")
    return coords
[docs]
def extract_cutouts(maps_nhw, coords, cutout_size, max_cutouts=500):
    """Extract square cutouts centred on ``(patch_idx, row, col)`` coordinates.

    For a coordinate ``(p, r, c)`` the cutout spans rows
    ``r - cutout_size // 2`` up to (excluding) ``r - cutout_size // 2 +
    cutout_size``, and the same for columns. Every cutout therefore has
    exactly ``cutout_size`` pixels per side. For odd sizes the coordinate
    is the exact centre pixel; for even sizes it sits at index
    ``cutout_size // 2``. Coordinates whose window would extend past the
    map edge are skipped and do not count toward ``max_cutouts``.

    Parameters
    ----------
    maps_nhw : ndarray, shape (N, H, W)
        Stack of N two-dimensional maps.
    coords : iterable of tuple
        Coordinates from :func:`select_snr_pixels`, consumed in order.
    cutout_size : int
        Side length of the square cutout in pixels.
    max_cutouts : int or None
        Maximum number of cutouts to return (caps memory use). *None*
        means no limit.

    Returns
    -------
    ndarray, shape (M, cutout_size, cutout_size) or None
        Stacked cutouts as float32, or *None* if no valid cutouts were found.
    """
    half = cutout_size // 2
    cutouts = []
    for patch_idx, ri, rj in coords:
        # Apply the cap to *kept* cutouts, so edge rejects don't reduce it.
        if max_cutouts is not None and len(cutouts) >= max_cutouts:
            break
        m = maps_nhw[patch_idx]
        ri0, rj0 = ri - half, rj - half
        # Offset from the start by the full size so odd sizes are not
        # truncated to 2 * (cutout_size // 2).
        ri1, rj1 = ri0 + cutout_size, rj0 + cutout_size
        # Negative starts must be rejected explicitly: slicing would wrap.
        if ri0 < 0 or ri1 > m.shape[0] or rj0 < 0 or rj1 > m.shape[1]:
            continue
        cutouts.append(m[ri0:ri1, rj0:rj1])
    return np.array(cutouts, dtype=np.float32) if cutouts else None