04 — Model Architecture and Training
Purpose: Walk through the U-Net DDPM architecture and training configuration.
This notebook documents the model design and training setup without requiring a full training run. It covers:
Model definition: instantiates the Unet with dim=64, dim_mults=(1, 2, 4, 8), channels=2 and flash_attn=True, giving 35.7 M trainable parameters across four encoder/decoder stages, and wraps it in GaussianDiffusion with 1000 timesteps, a sigmoid noise schedule and the v-prediction objective.
Data loading and augmentation: loads the stacked CIB+tSZ .npy arrays, performs the 80/20 train/validation split, and applies the 8× augmentation (4 rotations, each with and without a horizontal flip) via augment_images_unique.
Trainer configuration: walks through the Trainer1D hyperparameters (batch size 16, learning rate 1×10⁻⁴, 100,000 steps, gradient accumulation every 2 steps for an effective batch size of 32, EMA decay 0.995, mixed precision (fp16), and a checkpoint every 5,000 steps).
To actually train, run accelerate launch foregrounds_diffusion/train.py from the repo root instead of executing this notebook end-to-end.
Inputs:
CIB patches: data/low_pass/2mJy/CIB_map_150GHz_256_st6_minmax_2mJy_zero_lp.npy
tSZ patches: data/low_pass/2mJy/tSZ3_map_150GHz_256_st6_minmax_2mJy_norm_lp.npy
Outputs: the instantiated model and trainer (in memory) and a printed parameter count; no checkpoint is written.
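Key module functions: augment_images_unique from foregrounds_diffusion.preprocessing; the model, diffusion wrapper and trainer come directly from denoising_diffusion_pytorch.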
Paper reference: §3.1 (DDPM framework), Appendix A (Table 1 — architecture details).
1 Setup
GPU and precision configuration. Flash attention is available only on CUDA devices with compute capability ≥ 8.0 (Ampere or newer, e.g. A100, H100, RTX 30-series). On CPU or older GPUs, set flash_attn=False in the U-Net constructor. Mixed-precision (fp16) training roughly halves activation memory and typically gives a ≈ 2× throughput improvement on modern GPUs, with a negligible difference in the loss curve.
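The compute-capability rule above can be encoded as a small decision helper. This is a minimal sketch: the function name use_flash_attention is hypothetical, and on a real machine the (major, minor) tuple would come from torch.cuda.get_device_capability().

```python
from typing import Optional, Tuple


def use_flash_attention(device: str, capability: Optional[Tuple[int, int]]) -> bool:
    """Return True only for CUDA devices with compute capability >= 8.0.

    `capability` is a (major, minor) tuple, e.g. (8, 0) for an A100;
    pass None when it is unknown (e.g. on CPU).
    """
    if device != "cuda" or capability is None:
        return False
    # Tuple comparison: (8, 6) >= (8, 0) and (9, 0) >= (8, 0), but (7, 5) < (8, 0).
    return capability >= (8, 0)


# Hypothetical examples; no GPU is needed to run these.
print(use_flash_attention("cuda", (8, 0)))  # A100
print(use_flash_attention("cuda", (7, 0)))  # V100
print(use_flash_attention("cpu", None))
```

In the notebook the result could be passed as flash_attn=use_flash_attention(device, torch.cuda.get_device_capability() if device == "cuda" else None) instead of hard-coding True.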
[ ]:
import numpy as np
import torch
from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Trainer1D, Dataset1D
from foregrounds_diffusion.preprocessing import augment_images_unique
# Point-source flux cut in mJy; selects the data/low_pass/<PTSRC>mJy patch set.
PTSRC = 2
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Device: {device}")
2 U-Net and diffusion process
The denoiser is a 2D U-Net with four resolution levels: 64 → 128 → 256 → 512 feature channels. The two-channel input (CIB, tSZ) is processed jointly so that cross-channel spatial correlations are learnt by every attention and convolution layer.
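The channel widths follow directly from dim × dim_mults. The sketch below also lists spatial sizes for a 256 × 256 input (image_size=256), assuming the usual U-Net pattern of halving the resolution between consecutive levels with no downsampling after the deepest one; the helper unet_levels is hypothetical, and the spatial sizes should be checked against the library's Unet source.

```python
def unet_levels(dim: int, dim_mults: tuple, image_size: int):
    """List (channel width, spatial size) for each U-Net encoder level.

    Assumes the spatial size halves between consecutive levels and that
    there is no downsampling after the deepest level.
    """
    widths = [dim * m for m in dim_mults]
    sizes = [image_size // 2**i for i in range(len(dim_mults))]
    return list(zip(widths, sizes))


for width, size in unet_levels(dim=64, dim_mults=(1, 2, 4, 8), image_size=256):
    print(f"{width:4d} channels at {size:3d} x {size:3d}")
```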
GaussianDiffusion wraps the U-Net with a sigmoid noise schedule over T = 1000 timesteps; both the sigmoid schedule and T = 1000 are the defaults in denoising-diffusion-pytorch v2.2.5 (see also docs/paper_code_inconsistencies.md §noise schedule). The sigmoid schedule concentrates diffusion steps near t = 0 and t = T, where the signal-to-noise ratio changes most rapidly, and is combined with the v-prediction objective (objective="pred_v", also the library default) for improved sample stability.
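To make the schedule and objective concrete, here is a NumPy sketch of a sigmoid beta schedule and the v-prediction target. The functional form and the start=-3, end=3, tau=1 settings are modelled on the library's sigmoid_beta_schedule as we understand it and should be checked against the installed version; the toy 2 × 8 × 8 patch is made up for illustration. With ᾱ_t the cumulative product of (1 − β_t), the noised sample is x_t = √ᾱ_t x₀ + √(1 − ᾱ_t) ε, the target is v = √ᾱ_t ε − √(1 − ᾱ_t) x₀, and x₀ = √ᾱ_t x_t − √(1 − ᾱ_t) v recovers the clean sample.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def sigmoid_beta_schedule(timesteps=1000, start=-3.0, end=3.0, tau=1.0):
    """Betas derived from a sigmoid-shaped cumulative-alpha curve (assumed form)."""
    t = np.linspace(0, timesteps, timesteps + 1) / timesteps
    v_start, v_end = sigmoid(start / tau), sigmoid(end / tau)
    # Cumulative alpha falls smoothly from 1 at t = 0 to 0 at t = 1.
    alphas_cumprod = (-sigmoid((t * (end - start) + start) / tau) + v_end) / (v_end - v_start)
    alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
    betas = 1.0 - alphas_cumprod[1:] / alphas_cumprod[:-1]
    return np.clip(betas, 0.0, 0.999)


betas = sigmoid_beta_schedule()
alpha_bar = np.cumprod(1.0 - betas)  # cumulative product of (1 - beta) per timestep index

# v-prediction target for a toy 2-channel patch at timestep index t
rng = np.random.default_rng(0)
x0 = rng.uniform(-1, 1, size=(2, 8, 8))  # clean sample, already in [-1, 1]
eps = rng.standard_normal(x0.shape)      # Gaussian noise
t = 500
a, s = np.sqrt(alpha_bar[t]), np.sqrt(1.0 - alpha_bar[t])
x_t = a * x0 + s * eps                   # forward (noising) process
v = a * eps - s * x0                     # v-target the network regresses

# The clean sample is recoverable from (x_t, v).
x0_rec = a * x_t - s * v
print(np.allclose(x0_rec, x0))
```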
[2]:
# U-Net architecture parameters.
# dim=64 : base channel width (doubles at each downsampling level)
# dim_mults : (1,2,4,8) → channel widths 64, 128, 256, 512
# channels=2 : CIB + tSZ processed jointly (not independently)
# flash_attn : use memory-efficient attention (requires CUDA + compute ≥ 8.0)
#
# GaussianDiffusion wraps the U-Net with a sigmoid noise schedule
# (denoising-diffusion-pytorch default; see docs/paper_code_inconsistencies.md).
# image_size and timesteps must match the values used during training.
unet = Unet(
dim=64,
dim_mults=(1, 2, 4, 8),
channels=2, # CIB + tSZ
flash_attn=True,
)
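diffusion = GaussianDiffusion(
unet,
image_size=256,
timesteps=1000, # T = 1000 diffusion steps
beta_schedule="sigmoid", # library default, made explicit
objective="pred_v", # v-prediction (library default)
)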
diffusion = diffusion.to(device)
# Parameter count (paper reports 35.7 M)
total_params = sum(p.numel() for p in unet.parameters() if p.requires_grad)
print(f"Trainable parameters: {total_params:,} ({total_params / 1e6:.1f} M)")
Trainable parameters: 35,708,290 (35.7 M)
3 Load and split training data
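Load the normalised .npy arrays, stack CIB and tSZ into a single (N, H, W, 2) channels-last array, transpose it to PyTorch channels-first (N, 2, H, W), and make a seeded (seed 42) 80 / 20 train / validation split. The package helper split_data_to_tensors instead produces train / val / test splits (80 / 10 / 10 by default, seeded at 42) and does the transpose internally; the cell below keeps the simpler two-way split inline.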
[3]:
from pathlib import Path
PATCHES_DIR = Path(f"data/low_pass/{PTSRC}mJy")
fpath_cib = PATCHES_DIR / f"CIB_map_150GHz_256_st6_minmax_{PTSRC}mJy_zero_lp.npy"
fpath_tsz = PATCHES_DIR / f"tSZ3_map_150GHz_256_st6_minmax_{PTSRC}mJy_norm_lp.npy"
cib_maps = np.load(fpath_cib) # (N, H, W, 1)
tsz_maps = np.load(fpath_tsz) # (N, H, W, 1)
cut_maps = np.concatenate([cib_maps, tsz_maps], axis=-1) # (N, H, W, 2)
cut_maps = cut_maps.transpose(0, 3, 1, 2) # (N, 2, H, W)
print(f"Stacked patches: {cut_maps.shape}")
# 80 / 20 train / validation split (seeded for reproducibility)
rng = np.random.default_rng(seed=42)
indices = rng.permutation(len(cut_maps))
num_train = int(0.8 * len(cut_maps))
training_images = torch.tensor(cut_maps[indices[:num_train]], dtype=torch.float32)
val_images = torch.tensor(cut_maps[indices[num_train:]], dtype=torch.float32)
print(f"Train: {len(training_images)}, Val: {len(val_images)}")
Stacked patches: (674, 2, 256, 256)
Train: 539, Val: 135
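GaussianDiffusion is constructed with its default auto_normalize=True, under which the library rescales inputs from [0, 1] to [-1, 1] before adding noise. A quick per-channel range check on the loaded arrays guards against silently feeding data outside [0, 1]. The helper check_unit_range is hypothetical, and whether the tSZ "norm" maps are meant to lie in [0, 1] is an assumption to confirm.

```python
import numpy as np


def check_unit_range(maps: np.ndarray, names=("CIB", "tSZ"), atol: float = 1e-6) -> dict:
    """Report per-channel (min, max, within [0, 1]) for (N, C, H, W) maps."""
    report = {}
    for c, name in enumerate(names):
        lo, hi = float(maps[:, c].min()), float(maps[:, c].max())
        report[name] = (lo, hi, lo >= -atol and hi <= 1 + atol)
    return report


# Toy data standing in for cut_maps: channel 0 in [0, 1], channel 1 partly below 0.
toy = np.zeros((4, 2, 8, 8), dtype=np.float32)
toy[:, 0] = np.linspace(0, 1, 64).reshape(8, 8)
toy[:, 1] = np.linspace(-0.5, 1, 64).reshape(8, 8)
print(check_unit_range(toy))
```

In the notebook the equivalent call would be check_unit_range(cut_maps).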
4 Data augmentation
augment_images_unique applies all 8 elements of the dihedral group D₄ (the 4 rotations by multiples of 90°, each with and without a reflection) to every training patch, expanding the training set 8×. Crucially, each transformation is applied to the CIB and tSZ channels simultaneously, so the correlated spatial structure between the two channels is preserved under symmetry.
Note: the augmentation is applied before wrapping in a Dataset, so the full augmented set is held in memory. For very large N this can be chunked.
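The memory cost of holding the augmented set in memory follows from its shape. The sketch below assumes float32 storage (4 bytes per value), matching the dtype of training_images; the helper augmented_bytes is hypothetical. For the 539 training patches in this notebook it comes to about 2.3 GB, and about 4.2 GB for N_train = 1000; storing float16 would halve both figures.

```python
def augmented_bytes(n_train: int, n_aug: int = 8, channels: int = 2,
                    size: int = 256, bytes_per_value: int = 4) -> int:
    """Bytes needed for an array of shape (n_aug * n_train, channels, size, size)."""
    return n_aug * n_train * channels * size * size * bytes_per_value


for n in (539, 1000):
    print(f"N_train={n}: {augmented_bytes(n) / 1e9:.2f} GB (float32)")
```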
[ ]:
# 8× augmentation via all elements of the dihedral group D4
# (4 rotations × 2 flips, applied jointly to both channels).
# Memory (float32): augmented shape is (8*N_train, 2, 256, 256), ≈ 4.2 GB for N_train=1000 (≈ 2.3 GB for the 539 patches here).
augmented = augment_images_unique(training_images)
print(f"After 8× augmentation: {len(training_images)} → {len(augmented)} training samples")
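For readers without the foregrounds_diffusion package at hand, the following NumPy sketch shows one way to realise the joint 8-fold D₄ augmentation described above. It is not the package's augment_images_unique; the function name d4_augment, the output ordering, and the choice of a horizontal (W-axis) flip are assumptions. Because each operation acts only on the spatial axes, both channels of a patch always receive the same transformation.

```python
import numpy as np


def d4_augment(images: np.ndarray) -> np.ndarray:
    """Apply all 8 D4 symmetries to square (N, C, H, W) images, jointly across channels.

    Output order: for each k in 0..3, the rotation by k*90 degrees, then the same
    rotation followed by a horizontal flip. Output shape: (8*N, C, H, W).
    """
    out = []
    for k in range(4):
        rotated = np.rot90(images, k=k, axes=(2, 3))  # rotate the H, W axes only
        out.append(rotated)
        out.append(np.flip(rotated, axis=3))         # horizontal flip along W
    return np.ascontiguousarray(np.concatenate(out, axis=0))


rng = np.random.default_rng(0)
patches = rng.standard_normal((3, 2, 16, 16)).astype(np.float32)  # toy (N, C, H, W)
aug = d4_augment(patches)
print(aug.shape)
```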
5 Training loop
Trainer1D handles optimisation, EMA weight averaging, checkpoint saving, and optional WandB logging. The default training recipe:
Adam optimiser, lr = 10⁻⁴ with no schedule.
EMA decay = 0.995; EMA weights used for sampling.
Checkpoints every save_and_sample_every steps (5,000 in the cell below); image grids logged to WandB at each checkpoint.
Training for 20 000–50 000 steps is typical for convergence on this dataset, which takes ≈ 3–6 hours wall-clock on a single A100 GPU (SLURM script: train_slurm.sh); the configuration below allows up to train_num_steps = 100,000.
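The recipe's numbers translate into passes over the data as follows. This is plain arithmetic under the assumption that each Trainer1D step consumes train_batch_size × gradient_accumulate_every samples from the 8× augmented set (8 × 539 = 4312 samples here); the helper training_budget is hypothetical. The last line uses the standard rule of thumb that an EMA with decay β averages over roughly 1 / (1 − β) updates.

```python
def training_budget(n_samples: int, batch_size: int, grad_accum: int, steps: int) -> dict:
    """Effective batch size and number of passes over the data for a step budget.

    Assumes each optimiser step consumes batch_size * grad_accum samples.
    """
    effective_batch = batch_size * grad_accum
    steps_per_epoch = n_samples / effective_batch
    return {
        "effective_batch": effective_batch,
        "steps_per_epoch": steps_per_epoch,
        "epochs": steps / steps_per_epoch,
    }


n_aug = 8 * 539  # augmented training set size in this notebook
for steps in (20_000, 50_000, 100_000):
    b = training_budget(n_aug, batch_size=16, grad_accum=2, steps=steps)
    print(f"{steps:>7,} steps -> {b['epochs']:.0f} epochs (effective batch {b['effective_batch']})")

# EMA with decay 0.995 averages over roughly 1 / (1 - 0.995) EMA updates.
print(round(1 / (1 - 0.995)))
```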
[5]:
dataset = Dataset1D(augmented)
trainer = Trainer1D(
diffusion,
dataset=dataset,
train_batch_size=16,
num_samples=1,
train_lr=1e-4,
train_num_steps=100_000,
save_and_sample_every=5_000,
gradient_accumulate_every=2, # effective batch size = 16 × 2 = 32
ema_decay=0.995,
amp=True, # fp16 mixed precision
)
print("Trainer configured.")
print("To train: accelerate launch foregrounds_diffusion/train.py")
print("To resume: trainer.load(<milestone>), where milestone = step // save_and_sample_every")
Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
Trainer configured.
To train: accelerate launch foregrounds_diffusion/train.py
To resume: trainer.load(<milestone>), where milestone = step // save_and_sample_every
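As we understand denoising-diffusion-pytorch's Trainer1D, a checkpoint is written only when the step count is a multiple of save_and_sample_every, and it is labelled with the milestone index step // save_and_sample_every, which is what load() expects rather than the raw step. The helper milestone_for_step below is a hypothetical illustration of that mapping; confirm the checkpoint naming (we believe model-<milestone>.pt inside results_folder) against the installed version.

```python
def milestone_for_step(step: int, save_every: int = 5_000) -> int:
    """Checkpoint index written at a given step (assumes milestone = step // save_every)."""
    if step <= 0 or step % save_every != 0:
        raise ValueError(f"no checkpoint is written at step {step}")
    return step // save_every


saved = [milestone_for_step(s) for s in range(5_000, 100_001, 5_000)]
print(saved[0], saved[-1], len(saved))  # milestones 1..20 over 100,000 steps
print(milestone_for_step(20_000))       # value to pass to trainer.load(...)
```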