{ "cells": [ { "cell_type": "markdown", "id": "b523ca98", "metadata": {}, "source": [ "# 04 — Model Architecture and Training\n", "\n", "**Purpose:** Walk through the U-Net DDPM architecture and training configuration.\n", "\n", "This notebook documents the model design and training setup without requiring a full\n", "training run. It covers:\n", "\n", "1. **Model definition** — instantiates the `Unet` with `dim=64`,\n", " `dim_mults=(1, 2, 4, 8)`, `channels=2`, and `flash_attn=True`, giving 35.7 M\n", " trainable parameters across four encoder/decoder stages. Wraps it in\n", " `GaussianDiffusion` with 1000 timesteps, the sigmoid noise schedule, and the\n", " v-prediction objective (both library defaults).\n", "\n", "2. **Data loading and augmentation** — loads the stacked CIB+tSZ `.npy` arrays,\n", " performs the 80/20 train/validation split, and applies the 8× augmentation\n", " (4 rotations, each with and without a horizontal flip) via `augment_images_unique`.\n", "\n", "3. **Trainer configuration** — explains each `Trainer1D` hyperparameter: batch size 16,\n", " learning rate 1×10⁻⁴, 100,000 steps, gradient accumulation every 2 steps,\n", " EMA decay 0.995, mixed precision (fp16), checkpoint every 5,000 steps.\n", "\n", "To actually train, run `accelerate launch foregrounds_diffusion/train.py` from the repo\n", "root instead of executing this notebook end-to-end.\n", "\n", "**Inputs:**\n", "- CIB patches: `data/low_pass/2mJy/CIB_map_150GHz_256_st6_minmax_2mJy_zero_lp.npy`\n", "- tSZ patches: `data/low_pass/2mJy/tSZ3_map_150GHz_256_st6_minmax_2mJy_norm_lp.npy`\n", "\n", "**Outputs:** printed trainable-parameter count and train/validation split sizes (no checkpoint written).\n", "\n", "**Key module functions:** `foregrounds_diffusion.preprocessing.augment_images_unique`; the model, diffusion wrapper, and trainer come directly from `denoising_diffusion_pytorch`.\n", "\n", "**Paper reference:** §3.1 (DDPM framework), Appendix A (Table 1 — architecture details)." ] }, { "cell_type": "markdown", "id": "d265bbfd", "metadata": {}, "source": [ "## 1 Setup\n", "\n", "GPU and precision configuration. 
Flash attention is available only on CUDA\n", "devices with compute capability ≥ 8.0 (A100, H100, RTX 3090+). On CPU or\n", "older GPUs, set `flash_attn=False` in the U-Net constructor. Mixed-precision\n", "(`fp16`) training roughly halves activation memory and, on GPUs with Tensor\n", "Cores, typically speeds training up by as much as ≈ 2× with a negligible effect\n", "on the loss curve." ] }, { "cell_type": "code", "execution_count": null, "id": "5d0e14a4", "metadata": {}, "outputs": [], "source": [ "# Imports and global configuration.\n", "import numpy as np\n", "import torch\n", "from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer1D, Dataset1D\n", "\n", "from foregrounds_diffusion.preprocessing import augment_images_unique\n", "\n", "PTSRC = 2 # mJy label of the patch set; selects data/low_pass/{PTSRC}mJy\n", "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", "print(f\"Device: {device}\")\n" ] }, { "cell_type": "markdown", "id": "82834011", "metadata": {}, "source": [ "## 2 U-Net and diffusion process\n", "\n", "The denoiser is a 2D U-Net with four resolution levels carrying 64 → 128 → 256 → 512\n", "feature channels. The two-channel input `(CIB, tSZ)` is processed jointly, so\n", "every convolution and attention layer sees both channels and can learn their\n", "cross-channel spatial correlations.\n", "\n", "`GaussianDiffusion` wraps the U-Net with a **sigmoid noise schedule** over\n", "1000 timesteps (T = 1000), the default in `denoising-diffusion-pytorch`\n", "v2.2.5 (see also `docs/paper_code_inconsistencies.md` §noise schedule).\n", "Under this schedule the cumulative signal fraction ᾱₜ stays close to 1 near\n", "t = 0 and close to 0 near t = T, and falls fastest in the middle of the chain.\n", "It is combined with the library's default v-prediction objective, in which the\n", "network predicts v = √ᾱₜ ε − √(1 − ᾱₜ) x₀ rather than the noise ε, for improved\n", "sample stability." 
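, "\n", "\n", "The shape of the schedule can be checked with a short, stand-alone sketch in plain\n", "Python. It re-implements the sigmoid schedule under the assumption that the library\n", "uses `start=-3`, `end=3`, `tau=1` (compare against the installed version), and shows\n", "how the v-prediction target is formed from a clean value x₀ and noise ε.\n", "`sigmoid_alpha_bar` and `v_target` are hypothetical helpers, not library functions." ] }, { "cell_type": "code", "execution_count": null, "id": "a7c31e02", "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "\n", "def sigmoid(x):\n", "    return 1.0 / (1.0 + math.exp(-x))\n", "\n", "\n", "def sigmoid_alpha_bar(t, start=-3.0, end=3.0, tau=1.0):\n", "    # Hypothetical helper mirroring the sigmoid schedule (assumed start=-3, end=3, tau=1).\n", "    # t is the normalised time in [0, 1]; returns the cumulative signal fraction alpha_bar.\n", "    v_start = sigmoid(start / tau)\n", "    v_end = sigmoid(end / tau)\n", "    return (v_end - sigmoid((t * (end - start) + start) / tau)) / (v_end - v_start)\n", "\n", "\n", "def v_target(x0, eps, a_bar):\n", "    # Hypothetical helper: v = sqrt(a_bar) * eps - sqrt(1 - a_bar) * x0.\n", "    return math.sqrt(a_bar) * eps - math.sqrt(1.0 - a_bar) * x0\n", "\n", "\n", "T = 1000\n", "alpha_bar = [sigmoid_alpha_bar(i / T) for i in range(T + 1)]\n", "\n", "# Per-step drop in alpha_bar: small near t = 0 and t = T, largest mid-chain.\n", "drops = [alpha_bar[i] - alpha_bar[i + 1] for i in range(T)]\n", "print(f'alpha_bar(0) = {alpha_bar[0]:.3f}, alpha_bar(T) = {alpha_bar[T]:.3f}')\n", "print(f'per-step drop, first / middle / last: {drops[0]:.2e} / {drops[T // 2]:.2e} / {drops[-1]:.2e}')\n", "print(f'v target at t = T/2 for x0 = 0.5, eps = 1.0: {v_target(0.5, 1.0, alpha_bar[T // 2]):.3f}')"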
] }, { "cell_type": "code", "execution_count": 2, "id": "befd8d01", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Trainable parameters: 35,708,290 (35.7 M)\n" ] } ], "source": [ "# U-Net architecture parameters.\n", "# dim=64 : base channel width (doubles at each downsampling level)\n", "# dim_mults : (1,2,4,8) → channel widths 64, 128, 256, 512\n", "# channels=2 : CIB + tSZ processed jointly (not independently)\n", "# flash_attn : use memory-efficient attention (requires CUDA + compute ≥ 8.0)\n", "#\n", "# GaussianDiffusion wraps the U-Net with a sigmoid noise schedule and a\n", "# v-prediction objective (both denoising-diffusion-pytorch defaults;\n", "# see docs/paper_code_inconsistencies.md).\n", "# image_size and timesteps must match the values used during training.\n", "unet = Unet(\n", " dim=64,\n", " dim_mults=(1, 2, 4, 8),\n", " channels=2, # CIB + tSZ\n", " flash_attn=True,\n", ")\n", "\n", "diffusion = GaussianDiffusion(\n", " unet,\n", " image_size=256,\n", " timesteps=1000, # T = 1000 diffusion steps\n", ")\n", "diffusion = diffusion.to(device)\n", "\n", "# Parameter count (paper reports 35.7 M)\n", "total_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)\n", "print(f\"Trainable parameters: {total_params:,} ({total_params / 1e6:.1f} M)\")\n" ] }, { "cell_type": "markdown", "id": "5ce0660a", "metadata": {}, "source": [ "## 3 Load and split training data\n", "\n", "Load the normalised `.npy` arrays, concatenate CIB and tSZ along the last axis\n", "into a single `(N, H, W, 2)` channels-last array, and transpose it to PyTorch's\n", "channels-first layout `(N, 2, H, W)`. The patches are then shuffled with a\n", "seeded permutation (`seed=42`) and split 80 / 20 into training and validation\n", "sets; no separate test split is made in this notebook." 
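, "\n", "\n", "The split logic can be sanity-checked without the real maps by running it on small\n", "synthetic arrays with the same layout. In this sketch `split_80_20` is a hypothetical\n", "helper that mirrors the permutation-and-slice code used below, and the random arrays\n", "stand in for the CIB and tSZ patches (8 × 8 pixels rather than 256 × 256)." ] }, { "cell_type": "code", "execution_count": null, "id": "b4d9f610", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def split_80_20(maps, seed=42, frac=0.8):\n", "    # Hypothetical helper: seeded permutation; the first `frac` of the shuffled patches train.\n", "    rng = np.random.default_rng(seed=seed)\n", "    indices = rng.permutation(len(maps))\n", "    num_train = int(frac * len(maps))\n", "    return maps[indices[:num_train]], maps[indices[num_train:]], indices, num_train\n", "\n", "\n", "# Synthetic stand-ins for the CIB and tSZ patches, channels-last (N, H, W, 1).\n", "n, h, w = 674, 8, 8\n", "cib = np.random.default_rng(0).random((n, h, w, 1), dtype=np.float32)\n", "tsz = np.random.default_rng(1).random((n, h, w, 1), dtype=np.float32)\n", "stacked = np.concatenate([cib, tsz], axis=-1).transpose(0, 3, 1, 2)  # (N, 2, H, W)\n", "\n", "train, val, indices, num_train = split_80_20(stacked)\n", "print(stacked.shape, train.shape, val.shape)  # (674, 2, 8, 8) (539, 2, 8, 8) (135, 2, 8, 8)\n", "\n", "# Channel 0 holds CIB and channel 1 holds tSZ after the transpose.\n", "assert np.array_equal(stacked[:, 0], cib[..., 0]) and np.array_equal(stacked[:, 1], tsz[..., 0])\n", "# Training and validation indices never overlap and together cover every patch.\n", "assert set(indices[:num_train]).isdisjoint(indices[num_train:])\n", "assert len(train) + len(val) == n"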
] }, { "cell_type": "code", "execution_count": 3, "id": "b36225ff", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Stacked patches: (674, 2, 256, 256)\n", "Train: 539, Val: 135\n" ] } ], "source": [ "from pathlib import Path\n", "PATCHES_DIR = Path(f\"data/low_pass/{PTSRC}mJy\")\n", "fpath_cib = PATCHES_DIR / f\"CIB_map_150GHz_256_st6_minmax_{PTSRC}mJy_zero_lp.npy\"\n", "fpath_tsz = PATCHES_DIR / f\"tSZ3_map_150GHz_256_st6_minmax_{PTSRC}mJy_norm_lp.npy\"\n", "\n", "cib_maps = np.load(fpath_cib) # (N, H, W, 1)\n", "tsz_maps = np.load(fpath_tsz) # (N, H, W, 1)\n", "cut_maps = np.concatenate([cib_maps, tsz_maps], axis=-1) # (N, H, W, 2)\n", "cut_maps = cut_maps.transpose(0, 3, 1, 2) # (N, 2, H, W)\n", "print(f\"Stacked patches: {cut_maps.shape}\")\n", "\n", "# 80 / 20 train / validation split (seeded for reproducibility)\n", "rng = np.random.default_rng(seed=42)\n", "indices = rng.permutation(len(cut_maps))\n", "num_train = int(0.8 * len(cut_maps))\n", "training_images = torch.tensor(cut_maps[indices[:num_train]], dtype=torch.float32)\n", "val_images = torch.tensor(cut_maps[indices[num_train:]], dtype=torch.float32)\n", "print(f\"Train: {len(training_images)}, Val: {len(val_images)}\")\n" ] }, { "cell_type": "markdown", "id": "29356677", "metadata": {}, "source": [ "## 4 Data augmentation\n", "\n", "`augment_images_unique` applies all 8 elements of the dihedral group D₄\n", "(the 4 rotations by multiples of 90°, each with and without a reflection) to\n", "each training patch, expanding the training set 8×. Crucially, each transform\n", "is applied to *both* the CIB and tSZ channels of a patch simultaneously, so\n", "their correlated spatial structure is preserved.\n", "\n", "Note: the augmentation is applied **before** wrapping the data in a Dataset, so\n", "the full augmented set is held in memory. For very large N the augmentation can\n", "be done in chunks." 
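, "\n", "\n", "For reference, a minimal NumPy sketch of a joint D₄ augmentation follows. It is not\n", "the project's `augment_images_unique`, whose internals are not shown in this notebook;\n", "it assumes channels-first arrays of shape `(N, C, H, W)` and applies `np.rot90` and\n", "`np.flip` to the two spatial axes only, so CIB and tSZ always receive the same\n", "transform. `augment_d4` and `augmented_gigabytes` are hypothetical names; the second\n", "estimates the memory taken by the augmented float32 array." ] }, { "cell_type": "code", "execution_count": null, "id": "c2e8a5d7", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "\n", "\n", "def augment_d4(images):\n", "    # Hypothetical joint D4 augmentation for (N, C, H, W) arrays.\n", "    # Rotations and flips act only on the spatial axes (-2, -1), so every channel\n", "    # of a patch (here CIB and tSZ) receives the identical transform.\n", "    out = []\n", "    for k in range(4):  # rotations by 0, 90, 180 and 270 degrees\n", "        rotated = np.rot90(images, k=k, axes=(-2, -1))\n", "        out.append(rotated)\n", "        out.append(np.flip(rotated, axis=-1))  # same rotation, then a horizontal flip\n", "    return np.concatenate(out, axis=0)  # (8N, C, H, W), grouped by transform\n", "\n", "\n", "def augmented_gigabytes(n_train, channels=2, size=256, itemsize=4):\n", "    # 8 * N * C * H * W * itemsize bytes, expressed in GB (1e9 bytes).\n", "    return 8 * n_train * channels * size * size * itemsize / 1e9\n", "\n", "\n", "# Toy batch: 2 patches, 2 channels, 3 x 3 pixels; channel 1 = channel 0 + 9 in each patch.\n", "x = np.arange(2 * 2 * 3 * 3, dtype=np.float32).reshape(2, 2, 3, 3)\n", "aug = augment_d4(x)\n", "print(aug.shape)  # (16, 2, 3, 3)\n", "print(f'{augmented_gigabytes(539):.2f} GB for 539 training patches')  # 2.26 GB"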
] }, { "cell_type": "code", "execution_count": null, "id": "8743f751", "metadata": {}, "outputs": [], "source": [ "# 8× augmentation via all elements of the dihedral group D4\n", "# (4 rotations, each with and without a flip, applied jointly to both channels).\n", "# Memory: the augmented float32 array is (8*N_train, 2, 256, 256),\n", "# ≈ 2.3 GB for N_train=539 (≈ 4.2 GB for N_train=1000).\n", "augmented = augment_images_unique(training_images)\n", "print(f\"After 8× augmentation: {len(training_images)} → {len(augmented)} training samples\")\n" ] }, { "cell_type": "markdown", "id": "e2ac3d70", "metadata": {}, "source": [ "## 5 Training loop\n", "\n", "`Trainer1D` handles optimisation, EMA weight averaging, checkpoint saving,\n", "and optional WandB logging. The training recipe configured below:\n", "- **Adam** optimiser, lr = 10⁻⁴ with no schedule.\n", "- Batch size 16 with gradient accumulation over 2 steps (effective batch size 32), fp16 mixed precision.\n", "- EMA decay = 0.995; EMA weights used for sampling.\n", "- Checkpoints every 5,000 steps (`save_and_sample_every`); image grids logged to WandB at each checkpoint.\n", "\n", "The configuration allows 100,000 steps, but 20 000–50 000 steps are typically enough\n", "for convergence on this dataset; wall-clock time is ≈ 3–6 hours on a single A100 GPU\n", "(SLURM script: `train_slurm.sh`)." ] }, { "cell_type": "code", "execution_count": 5, "id": "80b3238e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. 
It is recommended to upgrade the kernel to the minimum version or higher.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Trainer configured.\n", "To train: accelerate launch foregrounds_diffusion/train.py\n", "To resume: trainer.load(milestone)\n" ] } ], "source": [ "dataset = Dataset1D(augmented)\n", "\n", "trainer = Trainer1D(\n", " diffusion,\n", " dataset=dataset,\n", " train_batch_size=16,\n", " num_samples=1, # samples drawn from the EMA model at each checkpoint\n", " train_lr=1e-4,\n", " train_num_steps=100_000,\n", " save_and_sample_every=5_000, # checkpoint (and sample) interval in steps\n", " gradient_accumulate_every=2, # effective batch size = 16 × 2 = 32\n", " ema_decay=0.995,\n", " amp=True, # fp16 mixed precision\n", ")\n", "print(\"Trainer configured.\")\n", "print(\"To train: accelerate launch foregrounds_diffusion/train.py\")\n", "print(\"To resume: trainer.load(milestone)\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "2563db63-a5ec-4376-9cbd-acee1010b862", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.13" } }, "nbformat": 4, "nbformat_minor": 5 }