04 — Model Architecture and Training

Purpose: Walk through the U-Net DDPM architecture and training configuration.

This notebook documents the model design and training setup without requiring a full training run. It covers:

  1. Model definition: instantiates the Unet with dim=64, dim_mults=(1, 2, 4, 8), channels=2, and flash_attn=True, giving 35.7 M trainable parameters across four encoder/decoder stages, and wraps it in GaussianDiffusion with 1000 timesteps, a sigmoid noise schedule, and the v-prediction objective.

  2. Data loading and augmentation: loads the CIB and tSZ .npy arrays, stacks them into two-channel patches, performs the 80/20 train/validation split, and applies the 8× augmentation (4 rotations × horizontal flip) via augment_images_unique.

  3. Trainer configuration: explains each Trainer1D hyperparameter (batch size 16, learning rate 1×10⁻⁴, 100,000 steps, gradient accumulation over 2 micro-batches per step for an effective batch size of 32, EMA decay 0.995, mixed precision (fp16), and a checkpoint every 5,000 steps).

To actually train, run accelerate launch foregrounds_diffusion/train.py from the repo root instead of executing this notebook end-to-end.

Inputs:

  • CIB patches: data/low_pass/2mJy/CIB_map_150GHz_256_st6_minmax_2mJy_zero_lp.npy

  • tSZ patches: data/low_pass/2mJy/tSZ3_map_150GHz_256_st6_minmax_2mJy_norm_lp.npy

Outputs: model graph, parameter count table (no checkpoint written).

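Key module functions: augment_images_unique (foregrounds_diffusion.preprocessing); the model, diffusion wrapper, and trainer come directly from denoising_diffusion_pytorch.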

Paper reference: §3.1 (DDPM framework), Appendix A (Table 1 — architecture details).

1 Setup

GPU and precision configuration. Flash attention is available only on CUDA devices with compute capability ≥ 8.0 (A100, H100, RTX 30-series and newer). On CPU or older GPUs, set flash_attn=False in the U-Net constructor. Mixed-precision (fp16) training roughly halves activation memory and can give up to ≈ 2× throughput on GPUs with tensor cores, with little visible difference in the loss curve.
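The compute-capability rule above can be expressed as a small guard that picks the flash_attn flag automatically. This is a minimal sketch assuming the ≥ 8.0 threshold stated here; use_flash_attn and FLASH_ATTN are hypothetical names, not part of denoising_diffusion_pytorch or this repo. torch.cuda.get_device_capability() returns a (major, minor) tuple for the current device.

```python
def use_flash_attn(capability):
    """Decide whether to pass flash_attn=True to Unet (hypothetical helper).

    capability: (major, minor) tuple as returned by
    torch.cuda.get_device_capability(), or None when no CUDA device exists.
    Assumes the compute-capability >= 8.0 threshold (Ampere and newer).
    """
    if capability is None:
        return False
    major, _minor = capability
    return major >= 8


try:
    import torch
    cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else None
except ImportError:  # torch missing: behave as a CPU-only machine
    cap = None

FLASH_ATTN = use_flash_attn(cap)
print(f"flash_attn={FLASH_ATTN}")
```

The result can then be passed as Unet(..., flash_attn=FLASH_ATTN), so the same notebook runs unchanged on CPU and on older GPUs.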

[ ]:
# Setup: imports, point-source threshold, and device selection.
import numpy as np
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer1D, Dataset1D

from foregrounds_diffusion.preprocessing import augment_images_unique

PTSRC = 2  # point-source threshold in mJy, as used in the data/low_pass/{PTSRC}mJy paths
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")

2 U-Net and diffusion process

The denoiser is a 2D U-Net with four resolution levels: 64 → 128 → 256 → 512 feature channels. The two-channel input (CIB, tSZ) is processed jointly so that cross-channel spatial correlations are learnt by every attention and convolution layer.
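The per-level widths and resolutions follow mechanically from dim, dim_mults, and the 256-pixel input. The sketch below reproduces that bookkeeping under two assumptions about the library's Unet (init_dim defaults to dim, and every stage except the last ends in a 2× downsample); unet_levels is a hypothetical helper, not part of denoising_diffusion_pytorch.

```python
def unet_levels(dim=64, dim_mults=(1, 2, 4, 8), image_size=256):
    """Return (level, in_ch, out_ch, resolution) for each encoder stage.

    Assumes init_dim == dim and that each stage but the last ends in a
    2x downsample, so stage i runs at image_size / 2**i pixels.
    """
    dims = [dim] + [dim * m for m in dim_mults]    # [64, 64, 128, 256, 512]
    in_out = list(zip(dims[:-1], dims[1:]))        # channel change per stage
    levels = []
    for i, (c_in, c_out) in enumerate(in_out):
        levels.append((i, c_in, c_out, image_size // 2 ** i))
    return levels


for level, c_in, c_out, res in unet_levels():
    print(f"level {level}: {c_in:>3} -> {c_out:>3} channels at {res}x{res}")
```

Under these assumptions the deepest stage runs at 32×32 pixels and widens to 512 channels, so the bottleneck (mid block) works on 512-channel, 32×32 feature maps; the decoder mirrors the same widths on the way back up.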

GaussianDiffusion wraps the U-Net with a sigmoid noise schedule over 1000 timesteps (T = 1000) — the default in denoising-diffusion-pytorch v2.2.5 (see also docs/paper_code_inconsistencies.md §noise schedule). The sigmoid schedule concentrates diffusion steps near t = 0 and t = T, where the signal-to-noise ratio changes most rapidly, and is combined with a v-prediction objective for improved sample stability.
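To make the schedule and objective concrete, here is a NumPy sketch of a sigmoid schedule with parameters (start=-3, end=3, tau=1) assumed to match the library's sigmoid_beta_schedule, followed by the v-prediction identities v = √ᾱ_t ε − √(1−ᾱ_t) x₀ and x₀ = √ᾱ_t x_t − √(1−ᾱ_t) v. It is an illustration written for this notebook, not the library's code; compare it against the installed denoising-diffusion-pytorch before relying on exact values.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_beta_schedule(timesteps=1000, start=-3.0, end=3.0, tau=1.0):
    """NumPy sketch of a sigmoid noise schedule (assumed to mirror the
    denoising-diffusion-pytorch default; check the installed version).

    alpha_bar(t) follows a rescaled, flipped sigmoid on t in [0, 1]; betas are
    recovered from ratios of consecutive alpha_bar values.
    """
    t = np.linspace(0.0, 1.0, timesteps + 1)
    v_start, v_end = sigmoid(start / tau), sigmoid(end / tau)
    alphas_cumprod = (v_end - sigmoid((t * (end - start) + start) / tau)) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return np.clip(betas, 0.0, 0.999)


betas = sigmoid_beta_schedule(1000)
alphas_cumprod = np.cumprod(1.0 - betas)           # alpha_bar_t for t = 0..999

# v-prediction on a toy 2-channel patch at an intermediate timestep.
rng = np.random.default_rng(0)
x0 = rng.standard_normal((2, 8, 8))
eps = rng.standard_normal(x0.shape)
a = alphas_cumprod[500]
x_t = np.sqrt(a) * x0 + np.sqrt(1.0 - a) * eps     # forward process q(x_t | x_0)
v = np.sqrt(a) * eps - np.sqrt(1.0 - a) * x0       # regression target for pred_v
x0_rec = np.sqrt(a) * x_t - np.sqrt(1.0 - a) * v   # x_0 recovered from an exact v
eps_rec = np.sqrt(1.0 - a) * x_t + np.sqrt(a) * v  # eps recovered from an exact v

for t in (0, 250, 500, 750, 999):
    snr = alphas_cumprod[t] / (1.0 - alphas_cumprod[t])
    print(f"t={t:4d}  alpha_bar={alphas_cumprod[t]:.4f}  log10 SNR={np.log10(snr):+.2f}")
```

The printed table shows log-SNR falling from large positive values near t = 0 to large negative values near t = T. The last lines of the v-prediction part show why predicting v loses nothing: both x₀ and ε are linear functions of (x_t, v).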

[2]:
# U-Net architecture parameters.
# dim=64        : base channel width (doubles at each downsampling level)
# dim_mults     : (1,2,4,8) → channel widths 64, 128, 256, 512
# channels=2    : CIB + tSZ processed jointly (not independently)
# flash_attn    : use memory-efficient attention (requires CUDA + compute ≥ 8.0)
#
# GaussianDiffusion wraps the U-Net with a sigmoid noise schedule
# (denoising-diffusion-pytorch default; see docs/paper_code_inconsistencies.md).
# image_size and timesteps must match the values used during training.
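unet = Unet(
    dim=64,
    dim_mults=(1, 2, 4, 8),
    channels=2,          # CIB + tSZ
    flash_attn=True,     # set False on CPU or GPUs with compute capability < 8.0
)

diffusion = GaussianDiffusion(
    unet,
    image_size=256,
    timesteps=1000,            # T = 1000 diffusion steps
    beta_schedule="sigmoid",   # library default, made explicit
    objective="pred_v",        # v-prediction (library default)
)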
diffusion = diffusion.to(device)

# Parameter count (paper reports 35.7 M)
total_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
print(f"Trainable parameters: {total_params:,}  ({total_params / 1e6:.1f} M)")

Trainable parameters: 35,708,290  (35.7 M)

3 Load and split training data

Load the normalised .npy arrays, concatenate CIB and tSZ along the last axis into a single (N, H, W, 2) channels-last array, and transpose it to PyTorch's channels-first layout (N, 2, H, W). A seeded permutation (seed 42) then splits the patches 80 / 20 into training and validation sets.

[3]:
from pathlib import Path
PATCHES_DIR = Path(f"data/low_pass/{PTSRC}mJy")
fpath_cib = PATCHES_DIR / f"CIB_map_150GHz_256_st6_minmax_{PTSRC}mJy_zero_lp.npy"
fpath_tsz = PATCHES_DIR / f"tSZ3_map_150GHz_256_st6_minmax_{PTSRC}mJy_norm_lp.npy"

cib_maps = np.load(fpath_cib)  # (N, H, W, 1)
tsz_maps = np.load(fpath_tsz)  # (N, H, W, 1)
cut_maps = np.concatenate([cib_maps, tsz_maps], axis=-1)  # (N, H, W, 2)
cut_maps = cut_maps.transpose(0, 3, 1, 2)                 # (N, 2, H, W)
print(f"Stacked patches: {cut_maps.shape}")

# 80 / 20 train / validation split (seeded for reproducibility)
rng = np.random.default_rng(seed=42)
indices  = rng.permutation(len(cut_maps))
num_train = int(0.8 * len(cut_maps))
training_images = torch.tensor(cut_maps[indices[:num_train]], dtype=torch.float32)
val_images      = torch.tensor(cut_maps[indices[num_train:]], dtype=torch.float32)
print(f"Train: {len(training_images)},  Val: {len(val_images)}")

Stacked patches: (674, 2, 256, 256)
Train: 539,  Val: 135
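To check the channel ordering and split logic without the real data files, the same steps can be run on synthetic arrays. This is a minimal sketch: stack_and_split is a hypothetical helper that mirrors the cell above but returns NumPy arrays instead of tensors, and the toy 8×8 arrays stand in for the 256×256 patches.

```python
import numpy as np


def stack_and_split(cib, tsz, train_frac=0.8, seed=42):
    """Stack (N, H, W, 1) CIB and tSZ arrays to (N, 2, H, W) and split train/val.

    Channel 0 is CIB and channel 1 is tSZ, matching the concatenation order.
    """
    maps = np.concatenate([cib, tsz], axis=-1).transpose(0, 3, 1, 2)
    idx = np.random.default_rng(seed=seed).permutation(len(maps))
    n_train = int(train_frac * len(maps))
    return maps[idx[:n_train]], maps[idx[n_train:]], idx


# Synthetic stand-ins for the real patches (random values, toy 8x8 size).
rng = np.random.default_rng(0)
cib = rng.random((10, 8, 8, 1))
tsz = rng.random((10, 8, 8, 1))
train, val, idx = stack_and_split(cib, tsz)
print(train.shape, val.shape)
```

Because the permutation is drawn from a seeded generator, rerunning the split gives the same indices, and int(0.8 × 674) = 539 reproduces the 539 / 135 split printed above.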

4 Data augmentation

augment_images_unique applies all 8 elements of the dihedral group D₄ (4 rotations, each with and without a reflection) to each training patch, expanding the training set 8×. Each transform is applied to the CIB and tSZ channels of a patch simultaneously, so their correlated spatial structure is preserved under the symmetry.
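The joint D₄ augmentation can be sketched with np.rot90 and np.flip acting only on the two spatial axes. d4_augment is a hypothetical stand-in written for illustration; the repo's augment_images_unique may order its outputs differently or operate on tensors.

```python
import numpy as np


def d4_augment(images):
    """Apply all 8 dihedral (D4) transforms to a (N, C, H, W) array.

    Rotations and flips act on the spatial axes (-2, -1) only, so every
    channel of a patch receives the same transform and CIB/tSZ stay aligned.
    Returns an (8*N, C, H, W) array ordered transform-major.
    """
    out = []
    for flip in (False, True):
        base = np.flip(images, axis=-1) if flip else images
        for k in range(4):                           # 0, 90, 180, 270 degrees
            out.append(np.rot90(base, k=k, axes=(-2, -1)))
    return np.concatenate(out, axis=0)


rng = np.random.default_rng(1)
patches = rng.random((3, 2, 16, 16))                 # toy (N, C, H, W) batch
aug = d4_augment(patches)
print(patches.shape, "->", aug.shape)
```

Restricting the axes to (-2, -1) is what keeps the two channels consistent: the channel axis is never permuted, so pixel (i, j) of CIB and pixel (i, j) of tSZ always move together.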

Note: the augmentation is applied before wrapping in a Dataset, so the full augmented set is held in memory. For very large N this can be chunked.
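One way to avoid holding all 8N augmented copies in memory is to apply the transforms lazily at indexing time. The sketch below is a hypothetical LazyD4Dataset, not part of this repo or denoising_diffusion_pytorch; it implements the map-style protocol (__len__ and __getitem__) that torch.utils.data.DataLoader accepts.

```python
import numpy as np


class LazyD4Dataset:
    """Map-style dataset that applies D4 transforms on the fly (hypothetical).

    Index i maps to patch i // 8 and transform i % 8, so only the N original
    patches are stored instead of 8*N augmented copies.
    """

    def __init__(self, images):
        self.images = images                          # (N, C, H, W)

    def __len__(self):
        return 8 * len(self.images)

    def __getitem__(self, i):
        patch = self.images[i // 8]
        t = i % 8
        if t >= 4:                                    # transforms 4-7 include a reflection
            patch = np.flip(patch, axis=-1)
        # ascontiguousarray: rot90/flip return views with negative strides,
        # which torch.from_numpy rejects
        return np.ascontiguousarray(np.rot90(patch, k=t % 4, axes=(-2, -1)))


ds = LazyD4Dataset(np.random.default_rng(2).random((5, 2, 16, 16)))
print(len(ds), ds[13].shape)
```

Dataset1D wraps a single pre-built tensor, so a lazy dataset like this would replace it as the dataset argument, assuming Trainer1D only relies on the map-style protocol when building its DataLoader. Storing the patches as float32 keeps batches in the dtype the model expects.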

[ ]:
# 8× augmentation via all elements of the dihedral group D4
# (4 rotations, each with and without a flip, applied jointly to both channels).
# Memory (float32): (8*N_train, 2, 256, 256) ≈ 2.3 GB for N_train=539, ≈ 4.2 GB for N_train=1000.
augmented = augment_images_unique(training_images)
print(f"After 8× augmentation: {len(training_images)} → {len(augmented)} training samples")

5 Training loop

Trainer1D handles optimisation, EMA weight averaging, checkpoint saving, and optional WandB logging. The default training recipe:

  • Adam optimiser, lr = 10⁻⁴ with no schedule.

  • EMA decay = 0.995; EMA weights used for sampling.

  • Checkpoints every 5,000 steps in this configuration (the Trainer1D default is 1,000); image grids logged to WandB at each checkpoint.
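The EMA recipe can be illustrated with a minimal sketch of the update rule θ_EMA ← β θ_EMA + (1 − β) θ with β = 0.995. It is a simplification: constant decay, NumPy arrays in place of model parameters, and a hypothetical helper name ema_update. The library delegates EMA to the ema_pytorch package, which (assuming its defaults are unchanged) also applies a warm-up, and Trainer1D by default updates the EMA copy only every 10 optimiser steps.

```python
import numpy as np


def ema_update(ema_params, params, decay=0.995):
    """One exponential-moving-average step: ema <- decay*ema + (1-decay)*params.

    Simplified: constant decay, no warm-up, no update interval.
    Mutates and returns ema_params.
    """
    for name, p in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * p
    return ema_params


# Toy check: if the raw weights jump from 0 to 1 and then stay fixed, the EMA
# copy approaches 1 with time constant ~ 1 / (1 - decay) = 200 updates.
ema = {"w": np.zeros(3)}
raw = {"w": np.ones(3)}
for _ in range(200):
    ema = ema_update(ema, raw)
print(ema["w"])
```

After n updates the toy EMA equals 1 − βⁿ, so with β = 0.995 it reaches about 63% of the new value after 200 updates. This lag is why EMA weights give smoother samples than the raw weights late in training.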

Training for 20,000–50,000 steps is typical for convergence on this dataset; wall-clock time is ≈ 3–6 hours on a single A100 GPU (SLURM script: train_slurm.sh).
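To relate these step counts to passes over the data, the arithmetic below assumes each Trainer1D step consumes gradient_accumulate_every micro-batches of train_batch_size samples, and uses the 539 training patches and 8× augmentation from the earlier cells. epochs_for_steps is a hypothetical helper.

```python
def epochs_for_steps(num_steps, n_train, aug_factor=8, batch_size=16, grad_accum=2):
    """Passes over the augmented training set for a given optimiser-step count.

    Assumes each step consumes grad_accum micro-batches of batch_size samples
    (effective batch = batch_size * grad_accum).
    """
    samples_per_epoch = n_train * aug_factor
    samples_per_step = batch_size * grad_accum
    return num_steps * samples_per_step / samples_per_epoch


n_train = 539                                   # training patches after the 80/20 split
for steps in (20_000, 50_000, 100_000):
    print(f"{steps:>7,} steps ≈ {epochs_for_steps(steps, n_train):.0f} epochs")
```

With 4,312 augmented samples and an effective batch of 32, one epoch is about 135 steps, so 20,000–50,000 steps correspond to roughly 148–371 epochs and the configured 100,000 steps to about 742.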

[5]:
dataset = Dataset1D(augmented)

trainer = Trainer1D(
    diffusion,
    dataset=dataset,
    train_batch_size=16,
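    num_samples=1,                  # samples drawn at each checkpoint (must be a perfect square)
    train_lr=1e-4,
    train_num_steps=100_000,
    save_and_sample_every=5_000,    # checkpoint and sample every 5,000 steps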
    gradient_accumulate_every=2,    # effective batch size = 16 × 2 = 32
    ema_decay=0.995,
    amp=True,                       # fp16 mixed precision
)
print("Trainer configured.")
print("To train:  accelerate launch foregrounds_diffusion/train.py")
print("To resume: trainer.load(<milestone>), where milestone = step // 5000")

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Trainer configured.
To train:  accelerate launch foregrounds_diffusion/train.py
To resume: trainer.load(<milestone>), where milestone = step // 5000