05 - Sampling and Post-Processing
Purpose: Generate CIB–tSZ map pairs from a trained checkpoint and apply post-sampling variance rescaling.
This notebook demonstrates the full generation workflow:
Loading a checkpoint: rebuilds the Unet + GaussianDiffusion model with the same architecture as training and loads weights from a .pt checkpoint via the load_checkpoint function in sample.py.
Generating samples: runs the reverse diffusion process to produce batches of (N, 2, 256, 256) CIB–tSZ patch pairs. Raw outputs are in the model's internal normalised range.
Variance rescaling: applies a post-sampling correction to recover the true pixel intensity scale. The paper describes multiplying by a single scalar factor (σ_Agora / σ_DDPM: 1.0328 for CIB, 1.1425 for tSZ); the codebase implements this via renormalize_dm_maps, which applies a two-step affine transform. Both approaches are shown and compared. See docs/paper_code_inconsistencies.md for details.
Inputs:
Trained checkpoint: results/model-20.pt
Training maps (for rescaling statistics): data/low_pass/2mJy/*.npy
Outputs:
Generated samples: data/low_pass/2mJy/new_samples_cib_tsz_2mJy_lp.npy
Key module functions:
foregrounds_diffusion.sample.build_model
foregrounds_diffusion.sample.load_checkpoint
foregrounds_diffusion.sample.sample
foregrounds_diffusion.preprocessing.renormalize_dm_maps
Paper reference: §3.2 (variance rescaling), §4 (generated sample evaluation).
1 Setup
Import the sampling utilities and define the paths and batch settings used by the rest of the notebook. The accelerator object created in the next section handles device placement and mixed precision automatically, so the same sampling code works on a single GPU, multiple GPUs (via accelerate launch), or CPU. Checkpoints are saved under results/<run_name>/ and named model-{step}.pt.
[1]:
# Imports, paths and sampling settings shared by the cells below.
import numpy as np
import torch
from accelerate import Accelerator
from pathlib import Path
from foregrounds_diffusion.sample import build_model, load_checkpoint, sample
from foregrounds_diffusion.preprocessing import renormalize_dm_maps
# Flux-cut label (mJy) used in directory and file names.
# It must be defined before the paths below, which interpolate it.
PTSRC = 2
N_BATCHES = 5    # number of sampling batches
BATCH_SIZE = 16  # samples generated per batch
PROJECT_ROOT = Path("/home/apb86/cmb_foregrounds_diffusion")
PATCHES_DIR = PROJECT_ROOT / "data" / "low_pass" / f"{PTSRC}mJy"
CHECKPOINT = PROJECT_ROOT / "results" / "model-20.pt"
OUTPUT_PATH = PATCHES_DIR / f"new_samples_cib_tsz_{PTSRC}mJy_lp.npy"
2 Load checkpoint
build_model constructs the U-Net + GaussianDiffusion wrapper with the same hyperparameters used at training time (dim = 64, dim_mults = (1,2,4,8), channels = 2, T = 1000). load_checkpoint restores the EMA (exponential moving average) weights rather than the instantaneous U-Net weights, because the averaged weights smooth out step-to-step training noise and typically give cleaner samples.
[2]:
# build_model() constructs the U-Net + GaussianDiffusion wrapper with the
# same hyperparameters as training. load_checkpoint() restores the EMA
# weights (not the instantaneous training weights); the averaged weights
# typically give sharper, less noisy samples.
accelerator = Accelerator(split_batches=True, mixed_precision='fp16')
print(f"Device: {accelerator.device}")
diffusion = build_model(channels=2)
diffusion = diffusion.to(accelerator.device)
diffusion = load_checkpoint(diffusion, CHECKPOINT, accelerator)
print("Checkpoint loaded.")
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Device: cpu
Checkpoint loaded.
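The difference between EMA and instantaneous weights can be illustrated with a toy example. The sketch below is not the code in sample.py or in the training library: the helper name `ema_update`, the decay of 0.995 and the noisy toy "weights" are all assumptions made for illustration.

```python
import numpy as np


def ema_update(ema_weights, model_weights, decay=0.995):
    """One EMA step: decay * ema + (1 - decay) * current weights.

    The decay value is an illustrative assumption, not taken from the project.
    """
    return decay * ema_weights + (1.0 - decay) * model_weights


rng = np.random.default_rng(0)
target = np.ones(4)   # stand-in for the weights training is converging to
ema = np.zeros(4)     # EMA copy, deliberately initialised far from the target
raw_history, ema_history = [], []

for step in range(2000):
    # Instantaneous weights jitter around the target from step to step.
    raw = target + 0.1 * rng.standard_normal(4)
    ema = ema_update(ema, raw)
    raw_history.append(raw)
    ema_history.append(ema)

# Compare step-to-step scatter over the last 500 steps.
raw_spread = np.std(np.array(raw_history[-500:]), axis=0).mean()
ema_spread = np.std(np.array(ema_history[-500:]), axis=0).mean()
print(f"raw spread: {raw_spread:.4f}, EMA spread: {ema_spread:.4f}")
```

The averaged copy ends up close to the target with much less scatter than the instantaneous values, which is the motivation for sampling from the EMA weights.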
3 Reverse diffusion sampling
sample runs the full reverse Markov chain: starting from a standard Gaussian noise tensor (batch_size, 2, 256, 256), the model iterates T = 1000 DDPM denoising steps to produce correlated CIB–tSZ patch pairs. All num_batches batches are concatenated along axis 0 to give the final sample array of shape (N_total, 2, 256, 256) in normalised space (CIB ∈ [0,1], tSZ ~ N(0,1)).
[3]:
# Reverse diffusion: T steps of denoising from N(0,I) → map
all_samples = sample(diffusion, accelerator,
num_batches=N_BATCHES, batch_size=BATCH_SIZE)
print(f"Raw sample shape (channels-first): {all_samples.shape}")
print(f"Raw value range: [{all_samples.min():.4f}, {all_samples.max():.4f}]")
Sampling batch 1/5 (0% complete)
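The reverse chain described in this section can be sketched in plain NumPy. This is a simplified illustration under stated assumptions, not the library implementation: the linear beta schedule, the function names, the small T and the tiny patch size are all chosen for the example, and `predict_noise` is a hypothetical placeholder for the trained U-Net. The library additionally clips the predicted clean image, which is omitted here.

```python
import numpy as np


def make_schedule(T, beta_start=1e-4, beta_end=0.02):
    """Linear beta schedule (an assumption) and the derived DDPM quantities."""
    betas = np.linspace(beta_start, beta_end, T)
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas)
    alphas_cumprod_prev = np.concatenate([[1.0], alphas_cumprod[:-1]])
    # Posterior variance of q(x_{t-1} | x_t, x_0); exactly zero at t = 0.
    posterior_var = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    return betas, alphas, alphas_cumprod, posterior_var


def predict_noise(x, t):
    """Hypothetical stand-in for the trained U-Net noise prediction."""
    return np.zeros_like(x)


def p_sample_loop(shape, T, rng):
    """Run T ancestral sampling steps starting from standard Gaussian noise."""
    betas, alphas, alphas_cumprod, posterior_var = make_schedule(T)
    x = rng.standard_normal(shape)
    for t in reversed(range(T)):
        eps = predict_noise(x, t)
        # Posterior mean written in terms of the predicted noise.
        mean = (x - betas[t] / np.sqrt(1.0 - alphas_cumprod[t]) * eps) / np.sqrt(alphas[t])
        # No noise is added on the final step (t == 0).
        noise = rng.standard_normal(shape) if t > 0 else 0.0
        x = mean + np.sqrt(posterior_var[t]) * noise
    return x


rng = np.random.default_rng(0)
# Three batches of four 2-channel patches, concatenated along axis 0.
batches = [p_sample_loop((4, 2, 8, 8), T=50, rng=rng) for _ in range(3)]
samples = np.concatenate(batches, axis=0)
print(samples.shape)  # (12, 2, 8, 8)
```

With the notebook settings (N_BATCHES = 5, BATCH_SIZE = 16) on a single process, the same concatenation gives 80 samples.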
4 Post-sampling variance rescaling
DDPM samples often have slightly compressed variance relative to the training distribution, an effect commonly attributed to EMA smoothing and the mean-squared-error training objective. The paper's scale factors (1.0328 for CIB, 1.1425 for tSZ) are both greater than one, consistent with the generated maps having lower variance than the training maps. We correct for this by computing the per-channel pixel standard deviation over the training set and rescaling each sample channel to match. The rescaling is applied in normalised space, before the normalisation is inverted to recover physical units (µK).
[ ]:
# Load training maps for rescaling statistics
cib_maps = np.load(PATCHES_DIR / f"CIB_map_150GHz_256_st6_minmax_{PTSRC}mJy_zero_lp.npy")
tsz_maps = np.load(PATCHES_DIR / f"tSZ3_map_150GHz_256_st6_minmax_{PTSRC}mJy_norm_lp.npy")
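# Assumes each array already has a trailing channel axis, i.e. shape (N, H, W, 1);
# if they are (N, H, W), np.stack([...], axis=-1) is needed instead.
train_maps = np.concatenate([cib_maps, tsz_maps], axis=-1) # (N, H, W, 2)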
# --- Code implementation: two-step affine transform ---
rescaled_code = renormalize_dm_maps(all_samples, train_maps, variance_scaling=True)
print(f"Rescaled (code) shape: {rescaled_code.shape}")
# --- Paper description: scalar multiply by sigma_Agora / sigma_DDPM ---
PAPER_SCALES = {'CIB': 1.0328, 'tSZ': 1.1425} # from paper §3.2
rescaled_paper = all_samples.copy()
for c, (name, scale) in enumerate(PAPER_SCALES.items()):
rescaled_paper[:, c] *= scale
print(f"{name} scale factor (paper): {scale}")
# See docs/paper_code_inconsistencies.md for discussion of the difference
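To make the difference between a pure scalar multiply and a mean-preserving affine rescaling concrete, here is a small sketch on synthetic data. It does not reproduce renormalize_dm_maps, whose internals are not shown in this notebook; `scalar_rescale` and `match_std` are hypothetical helpers, and the synthetic means and standard deviations are arbitrary.

```python
import numpy as np

AXES = (0, 2, 3)  # reduce over samples and pixels, keep the channel axis


def scalar_rescale(samples, factors):
    """Multiply each channel of an (N, C, H, W) array by its own factor."""
    return samples * np.asarray(factors)[None, :, None, None]


def match_std(samples, reference):
    """Affine variant: centre each channel, scale by the std ratio, restore the mean."""
    mu = samples.mean(axis=AXES, keepdims=True)
    ratio = reference.std(axis=AXES, keepdims=True) / samples.std(axis=AXES, keepdims=True)
    return (samples - mu) * ratio + mu


rng = np.random.default_rng(1)
# Synthetic "training" maps and variance-compressed "generated" maps.
# Channel 0 has a non-zero mean, channel 1 is centred on zero.
reference = np.moveaxis(
    rng.normal(loc=[0.3, 0.0], scale=[0.10, 1.00], size=(64, 16, 16, 2)), -1, 1)
samples = np.moveaxis(
    rng.normal(loc=[0.3, 0.0], scale=[0.09, 0.90], size=(64, 16, 16, 2)), -1, 1)

# Scalar factors of the form sigma_reference / sigma_generated.
factors = reference.std(axis=AXES) / samples.std(axis=AXES)
scaled = scalar_rescale(samples, factors)
matched = match_std(samples, reference)

print("factors:            ", factors)
print("mean before:        ", samples.mean(axis=AXES))
print("mean, scalar method:", scaled.mean(axis=AXES))
print("mean, affine method:", matched.mean(axis=AXES))
```

Both variants reproduce the reference standard deviation. They differ only when a channel has a non-zero mean: the scalar multiply also scales the mean, while the affine version leaves it unchanged.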
5 Save samples
Save the rescaled samples as a .npy array of shape (N, 2, H, W) (channels first). Downstream notebooks (06–12) load this file and denormalise to physical units using the norm_params file saved during patch extraction.
[ ]:
# OUTPUT_PATH is already a Path (see Setup); create its directory if needed.
OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)
np.save(OUTPUT_PATH, rescaled_code)
print(f"Saved {rescaled_code.shape[0]} samples → {OUTPUT_PATH}")
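Because the saved array is channels-first while the training maps in this notebook are channels-last, a downstream reader may need to convert between the two layouts. The following sketch shows a save/load round trip and the axis conversion on a tiny dummy array; the file name `demo_samples.npy` and the array contents are placeholders, not project files.

```python
import tempfile
from pathlib import Path

import numpy as np

# Dummy channels-first array standing in for the saved samples: (N, 2, H, W).
samples_cf = np.arange(2 * 2 * 3 * 3, dtype=np.float32).reshape(2, 2, 3, 3)

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "demo_samples.npy"  # placeholder file name
    np.save(path, samples_cf)
    loaded = np.load(path)                 # fully read into memory

# Split the two channels (index 0 = CIB, index 1 = tSZ in this notebook).
cib, tsz = loaded[:, 0], loaded[:, 1]

# Convert to channels-last (N, H, W, 2) for code that expects that layout.
samples_cl = np.moveaxis(loaded, 1, -1)

print(loaded.shape, samples_cl.shape)  # (2, 2, 3, 3) (2, 3, 3, 2)
```

np.moveaxis returns a view, so the conversion itself does not copy the data.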