{ "cells": [ { "cell_type": "markdown", "id": "fc4f8d31", "metadata": {}, "source": [ "# 05 — Sampling and Post-Processing\n", "\n", "**Purpose:** Generate CIB–tSZ map pairs from a trained checkpoint and apply\n", "post-sampling variance rescaling.\n", "\n", "This notebook demonstrates the full generation workflow:\n", "\n", "1. **Loading a checkpoint** — rebuilds the `Unet` + `GaussianDiffusion` model with the\n", " same architecture as training and loads weights from a `.pt` checkpoint via the\n", " `load_checkpoint` function in `sample.py`.\n", "\n", "2. **Generating samples** — runs the reverse diffusion process to produce batches of\n", " `(N, 2, 256, 256)` CIB–tSZ patch pairs. Raw outputs are in the model's internal\n", " normalised range.\n", "\n", "3. **Variance rescaling** — applies post-sampling correction to recover the true pixel\n", " intensity scale. The paper describes multiplying by a single scalar factor\n", " (σ_Agora / σ_DDPM: 1.0328 for CIB, 1.1425 for tSZ); the codebase implements this\n", " via `renormalize_dm_maps`, which applies a two-step affine transform. Both approaches\n", " are shown and compared. See `docs/paper_code_inconsistencies.md` for details.\n", "\n", "**Inputs:**\n", "- Trained checkpoint: `results/model-20.pt`\n", "- Training maps (for rescaling statistics): `data/low_pass/2mJy/*.npy`\n", "\n", "**Outputs:**\n", "- Generated samples: `data/low_pass/2mJy/new_samples_cib_tsz_2mJy_lp.npy`\n", "\n", "**Key module functions:**\n", "- `foregrounds_diffusion.sample.build_model`\n", "- `foregrounds_diffusion.sample.load_checkpoint`\n", "- `foregrounds_diffusion.sample.sample`\n", "- `foregrounds_diffusion.preprocessing.renormalize_dm_maps`\n", "\n", "**Paper reference:** §3.2 (variance rescaling), §4 (generated sample evaluation)." ] }, { "cell_type": "markdown", "id": "377fff8a", "metadata": {}, "source": [ "## 1 Setup\n", "\n", "Load the trained checkpoint and prepare the diffusion model for sampling. 
The\n", "`accelerator` object handles device placement and mixed precision automatically,\n", "so the same sampling code works on a single GPU, multiple GPUs (via\n", "`accelerate launch`), or CPU. Checkpoints are saved under `results//`\n", "and named `model-{step}.pt`." ] }, { "cell_type": "code", "execution_count": 1, "id": "51470a05", "metadata": {}, "outputs": [], "source": [ "# Imports, paths, and sampling settings used throughout this notebook.\n", "import numpy as np\n", "import torch\n", "from accelerate import Accelerator\n", "from pathlib import Path\n", "\n", "from foregrounds_diffusion.sample import build_model, load_checkpoint, sample\n", "from foregrounds_diffusion.preprocessing import renormalize_dm_maps\n", "\n", "# PTSRC selects the `{PTSRC}mJy` data directory, so it must be defined\n", "# before the paths that use it.\n", "PTSRC = 2\n", "\n", "PROJECT_ROOT = Path(\"/home/apb86/cmb_foregrounds_diffusion\")\n", "PATCHES_DIR = PROJECT_ROOT / \"data\" / \"low_pass\" / f\"{PTSRC}mJy\"\n", "\n", "CHECKPOINT = PROJECT_ROOT / \"results\" / \"model-20.pt\"\n", "OUTPUT_PATH = PATCHES_DIR / f\"new_samples_cib_tsz_{PTSRC}mJy_lp.npy\"\n", "N_BATCHES = 5    # number of sampling batches\n", "BATCH_SIZE = 16  # samples generated per batch\n" ] }, { "cell_type": "markdown", "id": "ccaef440", "metadata": {}, "source": [ "## 2 Load checkpoint\n", "\n", "`build_model` constructs the U-Net + `GaussianDiffusion` wrapper with the same\n", "hyperparameters used at training time (dim = 64, dim_mults = (1,2,4,8),\n", "channels = 2, T = 1000). `load_checkpoint` restores the EMA weights (not\n", "the raw U-Net weights), because EMA-averaged weights typically yield cleaner\n", "samples than the instantaneous weights." 
] }, { "cell_type": "code", "execution_count": 2, "id": "03cf34e8", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Device: cpu\n", "Checkpoint loaded.\n" ] } ], "source": [ "# build_model() constructs the U-Net + GaussianDiffusion wrapper with the\n", "# same hyperparameters as training. load_checkpoint() restores the EMA\n", "# weights (not the raw training weights); EMA-averaged weights typically\n", "# give cleaner samples than the instantaneous model weights.\n", "accelerator = Accelerator(split_batches=True, mixed_precision='fp16')\n", "print(f\"Device: {accelerator.device}\")\n", "\n", "diffusion = build_model(channels=2)\n", "diffusion = diffusion.to(accelerator.device)\n", "diffusion = load_checkpoint(diffusion, CHECKPOINT, accelerator)\n", "print(\"Checkpoint loaded.\")\n" ] }, { "cell_type": "markdown", "id": "4fde8872", "metadata": {}, "source": [ "## 3 Reverse diffusion sampling\n", "\n", "`sample` runs the full reverse Markov chain: starting from a standard Gaussian\n", "noise tensor of shape `(batch_size, 2, 256, 256)`, the model runs T = 1000 DDPM\n", "denoising steps to produce correlated CIB–tSZ patch pairs. All `num_batches`\n", "batches are concatenated along axis 0 to give the final sample array of shape\n", "`(N_total, 2, 256, 256)` in normalised space (CIB ∈ [0,1], tSZ ~ N(0,1))." 
] }, { "cell_type": "code", "execution_count": 3, "id": "d4d6f490", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sampling batch 1/5 (0% complete)\n" ] } ], "source": [ "# Reverse diffusion: T steps of denoising from N(0,I) → map\n", "all_samples = sample(diffusion, accelerator,\n", "                     num_batches=N_BATCHES, batch_size=BATCH_SIZE)\n", "print(f\"Raw sample shape (channels-first): 
{all_samples.shape}\")\n", "print(f\"Raw value range: [{all_samples.min():.4f}, {all_samples.max():.4f}]\")\n" ] }, { "cell_type": "markdown", "id": "7743b443", "metadata": {}, "source": [ "## 4 Post-sampling variance rescaling\n", "\n", "DDPM samples often have slightly compressed variance relative to the training\n", "distribution, an effect commonly attributed to EMA smoothing and the mean-squared\n", "error training objective. We correct for this by computing the per-channel\n", "pixel standard deviation over the training set and rescaling each sample\n", "channel to match. The rescaling is applied in normalised space before\n", "inverting the normalisation to physical units (µK)." ] }, { "cell_type": "code", "execution_count": null, "id": "51046ee1", "metadata": {}, "outputs": [], "source": [ "# Load training maps for rescaling statistics\n", "cib_maps = np.load(PATCHES_DIR / f\"CIB_map_150GHz_256_st6_minmax_{PTSRC}mJy_zero_lp.npy\")\n", "tsz_maps = np.load(PATCHES_DIR / f\"tSZ3_map_150GHz_256_st6_minmax_{PTSRC}mJy_norm_lp.npy\")\n", "# Assumes each array already has a trailing channel axis, i.e. (N, H, W, 1);\n", "# for (N, H, W) arrays, np.stack([...], axis=-1) would be needed instead.\n", "train_maps = np.concatenate([cib_maps, tsz_maps], axis=-1) # (N, H, W, 2)\n", "\n", "# --- Code implementation: two-step affine transform ---\n", "rescaled_code = renormalize_dm_maps(all_samples, train_maps, variance_scaling=True)\n", "print(f\"Rescaled (code) shape: {rescaled_code.shape}\")\n", "\n", "# --- Paper description: scalar multiply by sigma_Agora / sigma_DDPM ---\n", "# Assumes channel order (CIB, tSZ) along axis 1, matching the dict order.\n", "PAPER_SCALES = {'CIB': 1.0328, 'tSZ': 1.1425} # from paper §3.2\n", "rescaled_paper = all_samples.copy()\n", "for c, (name, scale) in enumerate(PAPER_SCALES.items()):\n", "    rescaled_paper[:, c] *= scale\n", "    print(f\"{name} scale factor (paper): {scale}\")\n", "\n", "# See docs/paper_code_inconsistencies.md for discussion of the difference\n" ] }, { "cell_type": "markdown", "id": "846d9f10", "metadata": {}, "source": [ "## 5 Save samples\n", "\n", "Save the rescaled samples as a `.npy` array of shape `(N, 2, H, W)` (channels\n", "first). 
Downstream notebooks (06–12) load this file and denormalise to\n", "physical units using the `norm_params` file saved during patch extraction." ] }, { "cell_type": "code", "execution_count": null, "id": "362c3918", "metadata": {}, "outputs": [], "source": [ "# OUTPUT_PATH is already a pathlib.Path (see Setup); create its directory if needed.\n", "OUTPUT_PATH.parent.mkdir(parents=True, exist_ok=True)\n", "np.save(OUTPUT_PATH, rescaled_code)\n", "print(f\"Saved {rescaled_code.shape[0]} samples → {OUTPUT_PATH}\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 5 }