{ "cells": [ { "cell_type": "markdown", "id": "aad143ec", "metadata": {}, "source": "# Paper figures\n\nGenerated from AGORA simulations and trained DDPM checkpoint.\nRun on the cluster where the data files are available.\n\n> §8 plot-style rewrite (Wong palette, cividis cmaps, PDF output) is tracked separately." }, { "cell_type": "code", "execution_count": null, "id": "2ffed9cf", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import healpy as hp\n", "import matplotlib.pyplot as plt\n", "import matplotlib.image as mpimg\n", "import matplotlib.patches as patches\n", "from matplotlib.patches import ConnectionPatch\n", "from matplotlib.lines import Line2D\n", "from matplotlib.gridspec import GridSpec\n", "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n", "from mpl_toolkits.axes_grid1.inset_locator import inset_axes\n", "from scipy.ndimage import gaussian_filter1d\n", "from pathlib import Path\n", "from random import sample\n", "\n", "# Plotting cells below call these names unqualified (title(), mean(), savefig(), ...).\n", "# No %pylab / star-import is used, so they must be imported explicitly or a fresh\n", "# kernel raises NameError on Restart & Run All.\n", "from numpy import linspace, mean, std\n", "from matplotlib.pyplot import (\n", " axhline, axis, axvspan, errorbar, fill_between, grid, legend,\n", " plot, savefig, subplots, tick_params, title, xlabel, xlim, ylim,\n", ")\n", "\n", "from foregrounds_diffusion.flatmaps import map2cl, cl2map, get_lxly\n", "from foregrounds_diffusion.plot_style import apply as _apply_style, WONG\n", "from foregrounds_diffusion.preprocessing import (\n", " apply_stdnorm,\n", " renormalize_dm_maps,\n", " denormalize_dm_maps,\n", " load_all_moments,\n", ")\n", "from foregrounds_diffusion.statistics import stats\n", "\n", "plt.rcParams.update({\n", " \"text.usetex\": True,\n", " \"font.family\": \"serif\",\n", " \"font.serif\": [\"Computer Modern Roman\", \"DejaVu Serif\"],\n", " \"font.size\": 12,\n", " \"figure.dpi\": 150,\n", " \"savefig.dpi\": 300,\n", " \"savefig.bbox\": \"tight\",\n", "})\n" ] }, { "cell_type": "code", "execution_count": null, "id": "afc96d26", "metadata": {}, "outputs": [], "source": [ "PROJECT_ROOT = Path(\"/home/apb86/cmb_foregrounds_diffusion\")\n", "DATA_DIR = PROJECT_ROOT / \"data\"\n", "PATCHES_DIR = DATA_DIR / \"low_pass\" / \"2mJy\"\n", "FIGURES_DIR = PROJECT_ROOT / \"plots\" / \"paper\"\n", "FIGURES_DIR.mkdir(parents=True, exist_ok=True)\n", "\n", "PTSRC = 2\n", "RES = 
256\n", "flatskymapparams = [RES, RES, 1.40625, 1.40625] # [nx, ny, dx_arcmin, dy_arcmin]\n", "\n", "WONG = _apply_style(fig_width_pt=510.0, n_cols=2)\n", "\n", "# Figure cells below save to (and the Fig 1 composite reads back from) the relative\n", "# ./figures directory, not FIGURES_DIR; create it so plt.savefig does not fail.\n", "Path(\"figures\").mkdir(exist_ok=True)\n" ] }, { "cell_type": "markdown", "id": "d21c5abe-a720-42bc-9b8e-11172675c2e4", "metadata": {}, "source": "## Fig 1" }, { "cell_type": "code", "execution_count": null, "id": "9999bc2a-55f3-402c-9a09-3c2d941dfddb", "metadata": {}, "outputs": [], "source": [ "cib_150_map = hp.read_map(DATA_DIR / \"mask_radio_cib_2mJy\" / \"mdpl2_150GHz_fullsky.fits\")" ] }, { "cell_type": "code", "execution_count": null, "id": "a10c3a44-c814-4fcc-b64a-36de8b6ce2d7", "metadata": {}, "outputs": [], "source": [ "hp.mollview(cib_150_map, title=\"\", min=0, max=60, unit=r\"$\\mu K$\", cmap='cividis')\n", "fig = plt.gcf()\n", "cax = fig.axes[-1]\n", "cax.tick_params(labelsize=20)\n", "plt.savefig(\"figures/cib_150_map.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "c2aded2a-c77b-4dd1-990f-654c6cc7778e", "metadata": {}, "outputs": [], "source": [ "cib_150_map_filtered = hp.read_map(DATA_DIR / \"mask_radio_cib_2mJy\" / \"mdpl2_150GHz_fullsky_lmax7000.fits\")" ] }, { "cell_type": "code", "execution_count": null, "id": "a660ef64-a635-4474-bc4d-967db8b7f488", "metadata": {}, "outputs": [], "source": [ "hp.mollview(cib_150_map_filtered, title=\"\", min=0, max=60, unit=r\"$\\mu K$\", cmap='cividis')\n", "fig = plt.gcf()\n", "cax = fig.axes[-1]\n", "cax.tick_params(labelsize=20,)\n", "plt.savefig(\"figures/cib_150_map_processed.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "f8970244-154a-4a1b-af7f-eb7f52dd4de8", "metadata": {}, "outputs": [], "source": [ "fname = PATCHES_DIR / \"CIB_map_150GHz_256_st6_minmax_2mJy_zero_lp.npy\"\n", "processed_maps = np.load(fname);" ] }, { "cell_type": "code", "execution_count": null, "id": "d0d8bafd-8d29-480e-8ce7-26e12704f989", "metadata": {}, "outputs": [], "source": [ "fig, ax = 
plt.subplots(figsize=(5, 5))\n", "im = ax.imshow(processed_maps[0], cmap='cividis', vmin=0, vmax=1)\n", "ax.set_xlabel(r\"$6^\\circ$\", fontsize=20)\n", "ax.set_ylabel(r\"$6^\\circ$\", fontsize=20)\n", "ax.grid()\n", "\n", "cbar = fig.colorbar(im, ax=ax, orientation='horizontal',\n", " fraction=0.046)\n", "cbar.set_ticks([0, 1])\n", "cbar.ax.tick_params(labelsize=20)\n", "ax.tick_params(axis='both', which='both',\n", " bottom=False, top=False,\n", " left=False, right=False,\n", " labelbottom=False, labelleft=False)\n", "\n", "plt.savefig(\"figures/cib_150_map_flatsky.png\", bbox_inches=\"tight\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "92758366-a519-43a7-9846-69023394197e", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(figsize=(12, 4))\n", "ax.axis('off')\n", "\n", "img_fullsky1 = mpimg.imread('figures/cib_150_map.png')\n", "img_fullsky2 = mpimg.imread('figures/cib_150_map_processed.png')\n", "img_flatsky = mpimg.imread('figures/cib_150_map_flatsky.png')\n", "\n", "nodes = {\n", " 'node1': (0.05, 0.35, 0.1, 0.3),\n", " 'node2': (0.43, 0.35, 0.1, 0.3),\n", " 'node3': (0.67, 0.35, 0.25, 0.3),\n", "}\n", "\n", "def place_image(img, coords, zoom=0.1):\n", " x = coords[0] + coords[2] / 2\n", " y = coords[1] + coords[3] / 2\n", " im = OffsetImage(img, zoom=zoom)\n", " ab = AnnotationBbox(im, (x, y), frameon=False)\n", " ax.add_artist(ab)\n", "\n", "place_image(img_fullsky1, nodes['node1'], zoom=0.1)\n", "place_image(img_fullsky2, nodes['node2'], zoom=0.1)\n", "place_image(img_flatsky, nodes['node3'], zoom=0.1)\n", "\n", "node1_bottom = (nodes['node1'][0] + nodes['node1'][2] / 2, nodes['node1'][1])\n", "node2_bottom = (nodes['node2'][0] + nodes['node2'][2] / 2, nodes['node2'][1])\n", "node3_bottom = (nodes['node3'][0] + nodes['node3'][2] / 2, nodes['node3'][1])\n", "\n", "arrow_offset = -0.3\n", "arrow_start12 = (node1_bottom[0] +0.1 , node1_bottom[1] + arrow_offset)\n", "arrow_end12 = (node2_bottom[0]-0.1, 
node2_bottom[1] + arrow_offset)\n", "arrow_start23 = (node2_bottom[0]+0.05, node2_bottom[1] + arrow_offset)\n", "arrow_end23 = (node3_bottom[0]-0.05, node3_bottom[1] + arrow_offset)\n", "\n", "ax.annotate(\"\", xy=arrow_end12, xytext=arrow_start12, \n", " arrowprops=dict(arrowstyle=\"->\", lw=1, shrinkA=5, shrinkB=5))\n", "midpoint12 = ((arrow_start12[0] + arrow_end12[0]) / 2, (arrow_start12[1] + arrow_end12[1]) / 2)\n", "ax.text(midpoint12[0], midpoint12[1] + 0.05, \"2mJy point source mask\", ha='center', fontsize=12)\n", "ax.text(midpoint12[0], midpoint12[1] - 0.09, \"Low pass filter at 7000\", ha='center', fontsize=12)\n", "\n", "ax.annotate(\"\", xy=arrow_end23, xytext=arrow_start23, \n", " arrowprops=dict(arrowstyle=\"->\", lw=1, shrinkA=5, shrinkB=5))\n", "midpoint23 = ((arrow_start23[0] + arrow_end23[0]) / 2, (arrow_start23[1] + arrow_end23[1]) / 2)\n", "ax.text(midpoint23[0], midpoint23[1] + 0.05, \"Extracting flatsky patches\", ha='center', fontsize=12)\n", "ax.text(midpoint23[0], midpoint23[1] - 0.09, \"Normalization 0 to 1\", ha='center', fontsize=12)\n", "plt.savefig(\"figures/map_processing.png\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/map_processing.pdf\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "1847a8d7-e287-4ccb-994d-42b609a3c927", "metadata": {}, "source": [ "## Multifrequency maps" ] }, { "cell_type": "code", "execution_count": null, "id": "696f652f-8c57-4772-acc1-4436baa7e0c7", "metadata": {}, "outputs": [], "source": [ "channels = [\"CIB95\", \"CIB150\", \"CIB857\"]\n", "filenames = [PATCHES_DIR / f\"cut_maps_RES_256_ANG_X_6.0_deg_2mJy_lp_{ch}.npy\" for ch in channels]\n", "\n", "train_maps_list = [np.load(fname) for fname in filenames]\n", "train_maps = np.concatenate(train_maps_list, axis=-1)" ] }, { "cell_type": "code", "execution_count": null, "id": "62892013-1a3e-450e-81ba-f6b94642b707", "metadata": {}, "outputs": [], "source": [ "dm_maps = np.load(PATCHES_DIR / 
\"new_samples_14_cib_2mJy_zero_6x6_w_au_lp_three.npy\")\n", "dm_maps = renormalize_dm_maps(dm_maps, train_maps, variance_scaling=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "122924b7-9da1-428c-831f-4e3845c64cc8", "metadata": {}, "outputs": [], "source": [ "example_idx = 100\n", "example_agora = apply_stdnorm(train_maps[example_idx])\n", "example_diffusion = apply_stdnorm(dm_maps[example_idx])\n", "cmap='cividis'\n", "\n", "plt.clf()\n", "plt.figure(figsize=(10, 7))\n", "plt.subplots_adjust(wspace=0.01)\n", "\n", "titles = [\"CIB 95 GHz\", \"CIB 150 GHz\", \"CIB 857 GHz\"]\n", "vmin_vmax_dic = {0: (-1, 4), 1: (-1, 4), 2: (-1, 4)}\n", "#vmin_vmax_dic = {0: (None, None), 1: (None, None), 2: (None, None)}\n", "\n", "sp_i = 1\n", "for data, model in zip([example_agora, example_diffusion], [\"Agora\", \"DDPM\"]):\n", " for i in range(3):\n", " plt.subplot(2, 3, sp_i)\n", " vmin, vmax = vmin_vmax_dic[i]\n", " plt.imshow(data[:, :, i], cmap=cmap, aspect='auto', vmin=vmin, vmax=vmax)\n", " title(f\"{model}: {titles[i]}\", fontsize=20)\n", " axis('off')\n", " sp_i += 1\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"figures/example_cib_triplet.pdf\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/example_cib_triplet.png\", bbox_inches=\"tight\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "25646b5c-9a1e-4428-b493-2cca105ceefa", "metadata": {}, "outputs": [], "source": [ "def correlation_coeff(cross, auto1, auto2):\n", " return np.array(cross) / np.sqrt(np.array(auto1) * np.array(auto2))" ] }, { "cell_type": "code", "execution_count": null, "id": "11409d6d-4ade-4843-9505-47d747ded46e", "metadata": {}, "outputs": [], "source": [ "n_samples = 200\n", "idxs = sample(range(dm_maps.shape[0]), n_samples)\n", "cross_pairs = [(0, 1), (0, 2), (1, 2)]\n", "\n", "auto_cls = {'train': {i: [] for i in range(3)},\n", " 'dm': {i: [] for i in range(3)}}\n", "cross_cls = {'train': {p: [] for p in cross_pairs},\n", " 'dm': {p: [] for 
p in cross_pairs}}\n", "\n", "ell, _ = map2cl(flatskymapparams, train_maps[0, :, :, 0])\n", "\n", "for i in idxs:\n", " for tag, maps in zip(['train', 'dm'], [train_maps, dm_maps]):\n", " for ch in range(3):\n", " m = maps[i, :, :, ch]\n", " _, cl = map2cl(flatskymapparams, m)\n", " auto_cls[tag][ch].append(cl)\n", " for ch1, ch2 in cross_pairs:\n", " m1, m2 = maps[i, :, :, ch1], maps[i, :, :, ch2]\n", " _, cl = map2cl(flatskymapparams, m1, m2)\n", " cross_cls[tag][(ch1, ch2)].append(cl)\n", "\n", "mean_auto_cls = {tag: {k: mean(v, axis=0) for k, v in auto.items()} for tag, auto in auto_cls.items()}\n", "mean_cross_cls = {tag: {k: mean(v, axis=0) for k, v in cross.items()} for tag, cross in cross_cls.items()}\n", "\n", "corr = {}\n", "for tag in ['train', 'dm']:\n", " corr[tag] = {\n", " (ch1, ch2): mean(\n", " correlation_coeff(cross_cls[tag][(ch1, ch2)],\n", " auto_cls[tag][ch1],\n", " auto_cls[tag][ch2]), axis=0)\n", " for (ch1, ch2) in cross_pairs\n", " }" ] }, { "cell_type": "code", "execution_count": null, "id": "3cf359c2-fff2-4d9c-93c6-b61ea5babe31", "metadata": {}, "outputs": [], "source": [ "from matplotlib import colormaps as cm\n", "pairs = [(0, 0), (0, 1), (0, 2), (1, 1), (1, 2), (2, 2)]\n", "labels = [\"95x95\", \"95x150\", \"95x857\", \"150x150\", \"150x857\", \"857x857\"]\n", "cmap = cm[\"RdBu_r\"]\n", "cmap_indices = [0.0, 0.1, 0.3, 0.7, 0.9, 1.0] \n", "colors = [cmap(i) for i in cmap_indices]\n", "\n", "fsval=14\n", "\n", "plt.clf()\n", "plt.figure(figsize=(10., 4.2))\n", "\n", "plt.subplot(1, 2, 1)\n", "for i, (ch1, ch2) in enumerate(pairs):\n", " if ch1 == ch2:\n", " cl_train = mean_auto_cls[\"train\"][ch1]\n", " cl_dm = mean_auto_cls[\"dm\"][ch1]\n", " else:\n", " cl_train = mean_cross_cls[\"train\"][(ch1, ch2)]\n", " cl_dm = mean_cross_cls[\"dm\"][(ch1, ch2)]\n", "\n", " plot(ell, (cl_train / cl_dm) - 1, lw=1, label=labels[i], color=colors[i])\n", "\n", "axhline(0, color='gray', ls='--', lw=1)\n", "xlabel(r\"Multipole $\\ell$\")\n", 
"plt.ylabel(r\"$C_\\ell^{\\mathrm{Agora}} / C_\\ell^{\\mathrm{DDPM}} - 1$\")\n", "xlim(300, 4000)\n", "ylim(-0.1, 0.1)\n", "legend(fontsize=fsval, ncol=2)\n", "\n", "plt.subplot(1, 2, 2)\n", "for i, (ch1, ch2) in enumerate(pairs):\n", " if ch1 == ch2:\n", " continue \n", "\n", " plot(ell, corr[\"train\"][(ch1, ch2)], lw=2, label=f\"{labels[i]} (Agora)\", color=colors[i])\n", " plot(ell, corr[\"dm\"][(ch1, ch2)], lw=2, ls='--', label=f\"{labels[i]} (DM)\", color=colors[i])\n", " \n", "xlabel(r\"Multipole $\\ell$\")\n", "plt.ylabel(\"Correlation Coefficient\")\n", "xlim(300, 4000)\n", "ylim(0.65, 1.05)\n", "axhline(1, color='gray', ls='--', lw=1)\n", "\n", "handles = [\n", " Line2D([0], [0], color='k', lw=2, linestyle='-', label='Agora'),\n", " Line2D([0], [0], color='k', lw=2, linestyle='--', label='DDPM')\n", "]\n", "\n", "legend(\n", " handles=handles,\n", " #ncols=3,\n", " fontsize=fsval,\n", " #loc='upper right',\n", ")\n", "\n", "plt.tight_layout()\n", "savefig(\"figures/cib_triplet_errors_and_correlation.pdf\", bbox_inches=\"tight\")\n", "savefig(\"figures/cib_triplet_errors_and_correlation.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "dc49aed0-abba-4bd5-88d5-5c4cb9be30be", "metadata": {}, "source": [ "## tSZ-CIB maps" ] }, { "cell_type": "code", "execution_count": null, "id": "9739c0a7-a8f2-4fc3-b6d4-562964d9aaf3", "metadata": {}, "outputs": [], "source": [ "filenames = [PATCHES_DIR / f\"cut_maps_RES_256_ANG_X_6.0_deg_2mJy_lp_CIB150.npy\",\n", " PATCHES_DIR / f\"cut_maps_RES_256_ANG_X_6.0_deg_2mJy_lp_tsz3.npy\"]\n", "\n", "agora_maps_list = [np.load(fname) for fname in filenames]\n", "agora_maps = np.concatenate(agora_maps_list, axis=-1)" ] }, { "cell_type": "code", "execution_count": null, "id": "03274c60-34e6-4187-98da-6a3e467d7378", "metadata": {}, "outputs": [], "source": [ "num_samples = len(agora_maps)\n", "num_train = int(0.8 * num_samples)\n", "rng = np.random.default_rng(seed=42)\n", "indices = 
rng.permutation(num_samples)\n", "train_indices = indices[:num_train]\n", "test_indices = indices[num_train:]\n", "train_maps = agora_maps[train_indices]\n", "test_maps = agora_maps[test_indices]" ] }, { "cell_type": "code", "execution_count": null, "id": "60a4f0c2-6719-4226-b84f-8971c76d47e2", "metadata": {}, "outputs": [], "source": [ "dm_maps = np.load(PATCHES_DIR / \"new_samples_16_cib_tsz3_2mJy_zero_norm_6x6_w_au_lp.npy\")\n", "dm_maps_unscaled = renormalize_dm_maps(dm_maps, train_maps, variance_scaling=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "40a34ef4-b73c-4a74-a177-6d7063ad6bb0", "metadata": {}, "outputs": [], "source": [ "dm_maps = np.load(PATCHES_DIR / \"new_samples_16_cib_tsz3_2mJy_zero_norm_6x6_w_au_lp.npy\")\n", "dm_maps = renormalize_dm_maps(dm_maps, train_maps, variance_scaling=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "f135d11e-f3b7-4d64-baab-8f28d3773c8b", "metadata": {}, "outputs": [], "source": [ "#np.save(\"data/low_pass/2mJy/new_samples_16_cib_tsz3_2mJy_zero_norm_6x6_w_au_lp_varscaled_no_pickle.npy\", dm_maps, allow_pickle=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "7da701a4-1c87-4a0d-abc1-165d20bf22af", "metadata": {}, "outputs": [], "source": [ "filenames = [PATCHES_DIR / f\"cut_maps_RES_256_ANG_X_6.0_deg_2mJy_lp_gaussian_cib_joint3.npy\",\n", " PATCHES_DIR / f\"cut_maps_RES_256_ANG_X_6.0_deg_2mJy_lp_gaussian_tsz_joint3.npy\"]\n", "\n", "gaussian_maps_list = [np.load(fname) for fname in filenames]\n", "gaussian_maps = np.concatenate(gaussian_maps_list, axis=-1)" ] }, { "cell_type": "code", "execution_count": null, "id": "c7c7d601-bc46-4b14-88ed-833d436888c9", "metadata": {}, "outputs": [], "source": [ "train_maps_std = apply_stdnorm(train_maps)\n", "dm_maps_std = apply_stdnorm(dm_maps)" ] }, { "cell_type": "code", "execution_count": null, "id": "554c32b5-d743-48f3-bea1-3a13c58565c9", "metadata": {}, "outputs": [], "source": [ "stats(train_maps[:,:,:,1])" ] }, { 
"cell_type": "code", "execution_count": null, "id": "78f525b8-629e-4df5-885c-624e92111810", "metadata": {}, "outputs": [], "source": [ "example_idxs = [68, 132, 85, 100]\n", "\n", "vmin_cib, vmax_cib = -1, 4\n", "vmin_tsz, vmax_tsz = -10, 3\n", "cmap_cib = 'cividis'\n", "cmap_tsz = 'RdBu_r'\n", "fsval = 16 # slightly larger y-axis labels\n", "sim_labels = [\"Sim1\", \"Sim2\", \"Sim3\", \"Sim4\"]\n", "\n", "plt.clf()\n", "plt.figure(figsize=(10, 10))\n", "\n", "for i, idx in enumerate(example_idxs):\n", " for row, data, cmap, vmin, vmax, label in zip(\n", " range(4),\n", " [train_maps_std, dm_maps_std, train_maps_std, dm_maps_std],\n", " [cmap_cib, cmap_cib, cmap_tsz, cmap_tsz],\n", " [vmin_cib, vmin_cib, vmin_tsz, vmin_tsz],\n", " [vmax_cib, vmax_cib, vmax_tsz, vmax_tsz],\n", " [\"CIB Agora\", \"CIB DDPM\", \"tSZ Agora\", \"tSZ DDPM\"]\n", " ):\n", " ax = plt.subplot(4, 4, i + 1 + row * 4)\n", " plt.imshow(data[idx, :, :, row // 2], cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", "\n", " if i == 0:\n", " plt.ylabel(label, fontsize=fsval+2)\n", "\n", " \n", " ax.text(\n", " 0.95, 0.85, sim_labels[i],\n", " transform=ax.transAxes,\n", " fontsize=14,\n", " color='black',\n", " ha='right',\n", " va='bottom',\n", " bbox=dict(boxstyle=\"round,pad=0.2\", facecolor=\"white\", alpha=1)\n", " )\n", "\n", "plt.tight_layout()\n", "savefig(\"figures/examples_grid_labeled.pdf\", bbox_inches=\"tight\")\n", "savefig(\"figures/examples_grid_labeled.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "c8836a77-7b70-42a6-ad82-efe00df94bc4", "metadata": {}, "outputs": [], "source": [ "example_idxs = [68, 132, 85, 100]\n", "\n", "vmin_cib, vmax_cib = -1, 4\n", "vmin_tsz, vmax_tsz = -10, 3\n", "cmap_cib = 'cividis'\n", "cmap_tsz = 'RdBu_r'\n", "fsval = 16\n", "sim_labels = [\"Sim1\", \"Sim2\", \"Sim3\", \"Sim4\"]\n", "\n", "plt.clf()\n", "plt.figure(figsize=(10, 10))\n", "\n", "for 
i, idx in enumerate(example_idxs):\n", " for row, data, cmap, vmin, vmax, label in zip(\n", " range(4),\n", " [train_maps_std, dm_maps_std, train_maps_std, dm_maps_std],\n", " [cmap_cib, cmap_cib, cmap_tsz, cmap_tsz],\n", " [vmin_cib, vmin_cib, vmin_tsz, vmin_tsz],\n", " [vmax_cib, vmax_cib, vmax_tsz, vmax_tsz],\n", " [\"CIB Agora\", \"CIB DDPM\", \"tSZ Agora\", \"tSZ DDPM\"]\n", " ):\n", " ax = plt.subplot(4, 4, i + 1 + row * 4)\n", " plt.imshow(data[idx, :, :, row // 2], cmap=cmap, vmin=vmin, vmax=vmax, aspect='auto')\n", " ax.set_xticks([])\n", " ax.set_yticks([])\n", "\n", " if i == 0:\n", " plt.ylabel(label, fontsize=fsval + 2)\n", "\n", " ax.text(\n", " 0.95, 0.85, sim_labels[i],\n", " transform=ax.transAxes,\n", " fontsize=14,\n", " color='black',\n", " ha='right',\n", " va='bottom',\n", " bbox=dict(boxstyle=\"round,pad=0.2\", facecolor=\"white\", alpha=1)\n", " )\n", "\n", "# Add zoomed-in insets to subplot (3,2) and (3,3), which are subplot indices 10 and 11 (0-based)\n", "for subplot_idx, zoom_coords in zip([14, 15], [(0.6, 0.6), (0.3, 0.3)]):\n", " ax_main = plt.gcf().axes[subplot_idx]\n", "\n", " # Create inset axes\n", " ax_inset = inset_axes(\n", " ax_main,\n", " width=\"30%\", height=\"30%\", loc='upper left',\n", " bbox_to_anchor=(1.05, 0.5, 0.3, 0.3), # (x0, y0, width, height) in axes fraction\n", " bbox_transform=ax_main.transAxes,\n", " borderpad=0\n", ")\n", "\n", " # Get image data and zoom in arbitrarily (adjust indices as needed)\n", " im_data = ax_main.images[0].get_array()\n", " zoom_slice = im_data[35:55, 35:55] # You can change this\n", "\n", " ax_inset.imshow(\n", " zoom_slice,\n", " cmap=ax_main.images[0].get_cmap(),\n", " vmin=ax_main.images[0].get_clim()[0],\n", " vmax=ax_main.images[0].get_clim()[1],\n", " aspect='auto'\n", " )\n", " ax_inset.set_xticks([])\n", " ax_inset.set_yticks([])\n", " ax_inset.set_title(\"Zoom\", fontsize=10)\n", "\n", " # Draw arrow from main plot to inset\n", " con = ConnectionPatch(\n", " xyA=zoom_coords, 
coordsA=ax_main.transAxes,\n", " xyB=(0.5, 0.5), coordsB=ax_inset.transAxes,\n", " axesA=ax_main, axesB=ax_inset,\n", " arrowstyle=\"->\", color=\"white\", linewidth=1.5\n", " )\n", " ax_main.add_artist(con)\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"figures/examples_grid_labeled.pdf\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/examples_grid_labeled.png\", bbox_inches=\"tight\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "db8f5681-2790-4e1e-93d0-49adfcf05679", "metadata": {}, "outputs": [], "source": [ "fname = DATA_DIR / \"ilc\" / \"ilc_weights_residuals_agora_fg_model.npy\"\n", "ilc_dict = np.load(fname, allow_pickle = True).item()\n", "total_ilc_residuals_dict = ilc_dict['total_ilc_residuals']" ] }, { "cell_type": "code", "execution_count": null, "id": "b453d3b8-f4e9-4305-8477-7b4677bafe51", "metadata": {}, "outputs": [], "source": [ "l, nl_s4_wide = total_ilc_residuals_dict['s4_wide']['mv']\n", "l, nl_s4_deep = total_ilc_residuals_dict['s4_deep']['mv']\n", "l, nl_spt3g = total_ilc_residuals_dict['spt3g']['mv']" ] }, { "cell_type": "code", "execution_count": null, "id": "9ce6651c-245b-4e58-9d13-f3678526f9b0", "metadata": {}, "outputs": [], "source": [ "n_samples = 200\n", "\n", "cls_test_cib_clean = []\n", "cls_test_tsz_clean = []\n", "cls_dm_cib_clean = []\n", "cls_dm_tsz_clean = []\n", "cls_dmu_cib_clean = []\n", "cls_dmu_tsz_clean = []\n", "\n", "cls_test_cib_noisy = []\n", "cls_test_tsz_noisy = []\n", "cls_dm_cib_noisy = []\n", "cls_dm_tsz_noisy = []\n", "cls_dmu_cib_noisy = []\n", "cls_dmu_tsz_noisy = []\n", "\n", "cls_test_cross_clean = []\n", "cls_test_cross_noisy = []\n", "cls_dm_cross_clean = []\n", "cls_dm_cross_noisy = []\n", "cls_dmu_cross_clean = []\n", "cls_dmu_cross_noisy = []" ] }, { "cell_type": "code", "execution_count": null, "id": "f66adfd6-0888-49e3-9ebd-d601825f1313", "metadata": {}, "outputs": [], "source": [ "for i in range(n_samples):\n", " noise = cl2map(flatskymapparams, nl_spt3g, l)\n", 
" \n", " cls_test_cib_clean.append(map2cl(flatskymapparams, test_maps[i, :, :, 0], test_maps[i, :, :, 0])[1])\n", " cls_test_tsz_clean.append(map2cl(flatskymapparams, test_maps[i, :, :, 1], test_maps[i, :, :, 1])[1])\n", " cls_test_cross_clean.append(map2cl(flatskymapparams, test_maps[i, :, :, 0], test_maps[i, :, :, 1])[1])\n", "\n", " noisy_cib = test_maps[i, :, :, 0] + noise\n", " noisy_tsz = test_maps[i, :, :, 1] + noise\n", " \n", " cls_test_cib_noisy.append(map2cl(flatskymapparams, noisy_cib, noisy_cib)[1])\n", " cls_test_tsz_noisy.append(map2cl(flatskymapparams, noisy_tsz, noisy_tsz)[1])\n", " cls_test_cross_noisy.append(map2cl(flatskymapparams, noisy_cib, noisy_tsz)[1])\n", "\n", " ###\n", " cls_dm_cib_clean.append(map2cl(flatskymapparams, dm_maps[i, :, :, 0], dm_maps[i, :, :, 0])[1])\n", " cls_dm_tsz_clean.append(map2cl(flatskymapparams, dm_maps[i, :, :, 1], dm_maps[i, :, :, 1])[1])\n", " cls_dm_cross_clean.append(map2cl(flatskymapparams, dm_maps[i, :, :, 0], dm_maps[i, :, :, 1])[1])\n", "\n", " noisy_dm_cib = dm_maps[i, :, :, 0] + noise\n", " noisy_dm_tsz = dm_maps[i, :, :, 1] + noise\n", " \n", " cls_dm_cib_noisy.append(map2cl(flatskymapparams, noisy_dm_cib, noisy_dm_cib)[1])\n", " cls_dm_tsz_noisy.append(map2cl(flatskymapparams, noisy_dm_tsz, noisy_dm_tsz)[1])\n", " cls_dm_cross_noisy.append(map2cl(flatskymapparams, noisy_dm_cib, noisy_dm_tsz)[1])\n", "\n", " ###\n", " cls_dmu_cib_clean.append(map2cl(flatskymapparams, dm_maps_unscaled[i, :, :, 0], dm_maps_unscaled[i, :, :, 0])[1])\n", " cls_dmu_tsz_clean.append(map2cl(flatskymapparams, dm_maps_unscaled[i, :, :, 1], dm_maps_unscaled[i, :, :, 1])[1])\n", " cls_dmu_cross_clean.append(map2cl(flatskymapparams, dm_maps_unscaled[i, :, :, 0], dm_maps_unscaled[i, :, :, 1])[1])\n", "\n", " noisy_dmu_cib = dm_maps_unscaled[i, :, :, 0] + noise\n", " noisy_dmu_tsz = dm_maps_unscaled[i, :, :, 1] + noise\n", " \n", " cls_dmu_cib_noisy.append(map2cl(flatskymapparams, noisy_dmu_cib, noisy_dmu_cib)[1])\n", " 
cls_dmu_tsz_noisy.append(map2cl(flatskymapparams, noisy_dmu_tsz, noisy_dmu_tsz)[1])\n", " cls_dmu_cross_noisy.append(map2cl(flatskymapparams, noisy_dmu_cib, noisy_dmu_tsz)[1])" ] }, { "cell_type": "code", "execution_count": null, "id": "01a4b278-dc94-40de-beb9-1fb1fd029541", "metadata": {}, "outputs": [], "source": [ "el, _ = map2cl(flatskymapparams, test_maps[0, :, :, 0])\n", "dl_factor = el * (el + 1) / (2 * np.pi)\n", "\n", "mean_test_cib = dl_factor * np.mean(cls_test_cib_clean, axis=0)\n", "mean_test_tsz = dl_factor * np.mean(cls_test_tsz_clean, axis=0)\n", "mean_test_cross = dl_factor * np.mean(cls_test_cross_clean, axis=0)\n", "\n", "mean_dm_cib = dl_factor * np.mean(cls_dm_cib_clean, axis=0)\n", "mean_dm_tsz = dl_factor * np.mean(cls_dm_tsz_clean, axis=0)\n", "mean_dm_cross = dl_factor * np.mean(cls_dm_cross_clean, axis=0)\n", "\n", "mean_dmu_cib = dl_factor * np.mean(cls_dmu_cib_clean, axis=0)\n", "mean_dmu_tsz = dl_factor * np.mean(cls_dmu_tsz_clean, axis=0)\n", "mean_dmu_cross = dl_factor * np.mean(cls_dmu_cross_clean, axis=0)\n", "\n", "std_test_cib = dl_factor * np.std(cls_test_cib_noisy, axis=0)\n", "std_test_tsz = dl_factor * np.std(cls_test_tsz_noisy, axis=0)\n", "std_test_cross = dl_factor * np.std(cls_test_cross_noisy, axis=0)\n", "\n", "bias_cross = (mean_test_cross - mean_dm_cross) / std_test_cross\n", "bias_crossu = (mean_test_cross - mean_dmu_cross) / std_test_cross\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7e9cc494-63f6-4229-9517-0ef236dea49f", "metadata": {}, "outputs": [], "source": [ "plt.clf()\n", "plt.figure(figsize=(10, 5))\n", "gs = GridSpec(2, 2, height_ratios=[3, 1], hspace=0.0)\n", "\n", "colors = {'CIB': WONG[5], 'tSZ': WONG[6], 'cross': WONG[7]}\n", "fsval = 14\n", "\n", "ax1 = plt.subplot(gs[0, 0])\n", "fill_between(el, mean_test_cib - std_test_cib, mean_test_cib + std_test_cib, color=colors['CIB'], alpha=0.2)\n", "plot(el, mean_test_cib, color=colors['CIB'], lw=1, ls='-', label=\"CIB (Agora)\")\n", "plot(el, 
mean_dm_cib, color=colors['CIB'], lw=1, ls='--', label=\"CIB (DDPM)\")\n", "#plot(el, mean_dmu_cib, color=colors['CIB'], lw=1, ls=':', )\n", "\n", "\n", "fill_between(el, mean_test_tsz - std_test_tsz, mean_test_tsz + std_test_tsz, color=colors['tSZ'], alpha=0.2)\n", "plot(el, mean_test_tsz, color=colors['tSZ'], lw=1, ls='-', label=\"tSZ (Agora)\")\n", "plot(el, mean_dm_tsz, color=colors['tSZ'], lw=1,ls='--', label=\"tSZ (DDPM)\")\n", "#plot(el, mean_dmu_tsz, color=colors['tSZ'], lw=1,ls=':')\n", "\n", "\n", "xlim(300, 4200)\n", "ylim(-1, 30)\n", "plt.ylabel(r\"$\\ell(\\ell+1)/2\\pi\\ C_\\ell \\ \\left[\\mu K^2\\right]$\", fontsize=fsval)\n", "tick_params(axis='both', labelsize=fsval)\n", "legend(fontsize=fsval, loc='upper left', frameon=True)\n", "\n", "ax2 = plt.subplot(gs[1, 0], sharex=ax1)\n", "bias_cib = (mean_test_cib - mean_dm_cib) / std_test_cib\n", "bias_tsz = (mean_test_tsz - mean_dm_tsz) / std_test_tsz\n", "plot(el, bias_cib, color=colors['CIB'], lw=1)\n", "plot(el, bias_tsz, color=colors['tSZ'], lw=1)\n", "\n", "axhline(0, color='gray', ls='--', lw=1)\n", "xlabel(r\"Multipole $\\ell$\", fontsize=fsval)\n", "plt.ylabel(r\"$\\Delta C_\\ell / \\sigma C_\\ell$\", fontsize=fsval)\n", "ylim(-0.4, 0.4)\n", "tick_params(axis='both', labelsize=fsval)\n", "\n", "\n", "ax3 = plt.subplot(gs[0, 1])\n", "fill_between(el, mean_test_cross - std_test_cross, mean_test_cross + std_test_cross,\n", " color=colors['cross'], alpha=0.2)\n", "\n", "plot(el, mean_test_cross, color=colors['cross'], lw=1, ls='-', label=\"CIB × tSZ (Agora)\")\n", "plot(el, mean_dm_cross, color=colors['cross'], lw=1, ls='--', label=\"CIB × tSZ (DDPM)\")\n", "#plot(el, mean_dmu_cross, color=colors['cross'], lw=1, ls=':',)\n", "\n", "xlim(300, 4200)\n", "ylim(-2.9, 4)\n", "#plt.ylabel(r\"$\\ell(\\ell+1)/2\\pi\\ C_\\ell$\", fontsize=fsval)\n", "tick_params(axis='both', labelsize=fsval)\n", "legend(fontsize=fsval, loc='upper left', frameon=True)\n", "\n", "ax4 = plt.subplot(gs[1, 1], sharex=ax3)\n", 
"plot(el, bias_cross, color=colors['cross'], lw=1)\n", "axhline(0, color='gray', ls='--', lw=1)\n", "xlabel(r\"Multipole $\\ell$\", fontsize=fsval)\n", "ylim(-0.4, 0.4)\n", "tick_params(axis='both', labelsize=fsval)\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"figures/power_spectra_comparison_cib_tsz_combined.pdf\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/power_spectra_comparison_cib_tsz_combined.png\", bbox_inches=\"tight\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "17a5f023-d010-4afa-a032-b65fc6309613", "metadata": {}, "outputs": [], "source": [ "def smooth_hist(pixels, bins):\n", " hist, _ = np.histogram(pixels, bins=bins, density=True)\n", " return gaussian_filter1d(hist, sigma=2)\n", "\n", "pixels_training_cib = train_maps[:n_samples, :, :, 0].flatten()\n", "pixels_training_tsz = -1*train_maps[:n_samples, :, :, 1].flatten()\n", "pixels_samples_cib = dm_maps[:n_samples, :, :, 0].flatten()\n", "pixels_samples_tsz = -1*dm_maps[:n_samples, :, :, 1].flatten()\n", "pixels_samplesu_cib = dm_maps_unscaled[:n_samples, :, :, 0].flatten()\n", "pixels_samplesu_tsz = -1*dm_maps_unscaled[:n_samples, :, :, 1].flatten()\n", "\n", "#bins_cib = logspace(0., 2, 500)\n", "#bins_tsz = logspace(0., 2, 500)\n", "bins_cib = linspace(0., 70, 500)\n", "bins_tsz = linspace(0., 90, 500)\n", "\n", "bin_centers_cib = 0.5 * (bins_cib[:-1] + bins_cib[1:])\n", "train_hist_cib = smooth_hist(pixels_training_cib, bins_cib)\n", "gen_hist_cib = smooth_hist(pixels_samples_cib, bins_cib)\n", "genu_hist_cib = smooth_hist(pixels_samplesu_cib, bins_cib)\n", "\n", "bin_centers_tsz = 0.5 * (bins_tsz[:-1] + bins_tsz[1:])\n", "train_hist_tsz = smooth_hist(pixels_training_tsz, bins_tsz)\n", "gen_hist_tsz = smooth_hist(pixels_samples_tsz, bins_tsz)\n", "genu_hist_tsz = smooth_hist(pixels_samplesu_tsz, bins_tsz)" ] }, { "cell_type": "code", "execution_count": null, "id": "cd1c52f3-923b-47e0-9b07-3103f622347b", "metadata": { "jupyter": { "source_hidden": true 
} }, "outputs": [], "source": [ "plt.clf()\n", "fig, axs = subplots(2, 2, figsize=(8, 6), sharex='col', sharey='row', gridspec_kw={'wspace': 0, 'hspace': 0})\n", "\n", "fsval = 14\n", "colors = {'DM': WONG[5], 'train': 'black'}\n", "\n", "axs[0, 0].plot(bin_centers_cib, train_hist_cib, color=colors['train'], lw=1, label='Agora')\n", "axs[0, 0].plot(bin_centers_cib, gen_hist_cib, color=colors['DM'], lw=1, label='DDPM')\n", "axs[0, 0].set_ylabel(\"Normalized counts\", fontsize=fsval)\n", "axs[0, 0].set_title(\"CIB one-point PDF\", fontsize=fsval)\n", "axs[0, 0].legend(fontsize=fsval)\n", "axs[0, 0].tick_params(labelbottom=False)\n", "\n", "axs[0, 1].plot(bin_centers_tsz, train_hist_tsz, color=colors['train'], lw=1, label='Agora')\n", "axs[0, 1].plot(bin_centers_tsz, gen_hist_tsz, color=colors['DM'], lw=1, label='DDPM')\n", "axs[0, 1].set_title(\"tSZ one-point PDF\", fontsize=fsval)\n", "axs[0, 1].tick_params(labelbottom=False, labelleft=False)\n", "\n", "axs[1, 0].plot(bin_centers_cib, train_hist_cib, color=colors['train'], lw=1)\n", "axs[1, 0].plot(bin_centers_cib, gen_hist_cib, color=colors['DM'], lw=1)\n", "axs[1, 0].set_yscale('log')\n", "axs[1, 0].set_xlabel(\"Pixel Intensity\", fontsize=fsval)\n", "axs[1, 0].set_ylabel(\"Normalized counts\", fontsize=fsval)\n", "\n", "axs[1, 1].plot(bin_centers_tsz, train_hist_tsz, color=colors['train'], lw=1)\n", "axs[1, 1].plot(bin_centers_tsz, gen_hist_tsz, color=colors['DM'], lw=1)\n", "axs[1, 1].set_yscale('log')\n", "axs[1, 1].set_xlabel(\"Pixel Intensity\", fontsize=fsval)\n", "axs[1, 1].tick_params(labelleft=False)\n", "\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"figures/histogram_comparison_curve_2panel.pdf\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/histogram_comparison_curve_2panel.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "45d88a8a-a174-4c4d-99e9-8066a74fac35", "metadata": {}, "outputs": [], "source": [ "plt.clf()\n", "fig, axs = 
plt.subplots(1, 2, figsize=(8, 3), sharey=True, gridspec_kw={'wspace': 0})\n", "\n", "fsval = 14\n", "colors = {'DM': 'steelblue', 'train': 'black','Unscaled DM': 'red'}\n", "\n", "axs[0].plot(-bin_centers_tsz, train_hist_tsz, color=colors['train'], lw=1.5, label='Agora')\n", "axs[0].plot(-bin_centers_tsz, gen_hist_tsz, color=colors['DM'], lw=1.5, label='DDPM')\n", "axs[0].plot(-bin_centers_tsz, genu_hist_tsz, color=colors['Unscaled DM'], lw=1, alpha=0.7, ls='--', label='Unscaled DDPM')\n", "axs[0].set_yscale('log')\n", "axs[0].set_xlabel(r\"Pixel Intensity [$\\mu K$]\", fontsize=fsval)\n", "axs[0].set_ylabel(\"Normalized counts\", fontsize=fsval)\n", "axs[0].set_ylim(5e-9,5e-1)\n", "axs[0].set_xlim(-90,0)\n", "\n", "axs[0].set_title(\"tSZ one-point PDF\", fontsize=fsval)\n", "axs[0].legend(fontsize=fsval)\n", "\n", "axs[1].plot(bin_centers_cib, train_hist_cib, color=colors['train'], lw=1.5)\n", "axs[1].plot(bin_centers_cib, gen_hist_cib, color=colors['DM'], lw=1.5)\n", "axs[1].plot(bin_centers_cib, genu_hist_cib, color=colors['Unscaled DM'], lw=1, alpha=0.7, ls=\"--\")\n", "axs[1].set_yscale('log')\n", "axs[1].set_ylim(1e-7,4e-1)\n", "axs[1].set_xlim(0,70)\n", "\n", "axs[1].set_xlabel(r\"Pixel Intensity [$\\mu K$]\", fontsize=fsval)\n", "axs[1].set_title(\"CIB one-point PDF\", fontsize=fsval)\n", "axs[1].tick_params(labelleft=False)\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"figures/histogram_comparison.pdf\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/histogram_comparison.png\", bbox_inches=\"tight\")\n", "plt.show()\n" ] }, { "cell_type": "markdown", "id": "052c9c88-1610-49b6-9fef-276e45a402b9", "metadata": {}, "source": [ "## Moments" ] }, { "cell_type": "code", "execution_count": null, "id": "07d9440e-1db9-42d7-9507-c739122c9d23", "metadata": {}, "outputs": [], "source": [ "# Raw strings: a backslash-mu in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+).\n", "labels = [r\"Bispectrum $S_{3}$ [$\\mu$K$^{3}$]\", r\"Trispectrum $S_{4}$ [$\\mu$K$^{4}$]\"]\n", "title_labels = [\n", " r\"$S_3$\",r\"$S_4$\",\n", "]\n", "\n", "\n", "moment_keys = 
[\"moment_03\", \"moment_07\"]\n", "mul_fac_dic = {0: 1e4, 1: 1e6}\n", "\n", "noise_levels = [\"s4deep_noise\"]\n", "\n", "models = [\"train\", \"diffusion\", \"gaussian\"]\n", "model_labels = [\"Agora\", \"DDPM\", \"Gaussian\"]\n", "markers = {\"train\": \".\", \"diffusion\": \".\", \"gaussian\": \".\"}\n", "\n", "noise_colors = {\n", " #\"spt3g_noise\": \"goldenrod\",\n", " #\"s4wide_noise\": \"steelblue\",\n", " \"s4deep_noise\": \"mediumvioletred\"\n", "}\n", "\n", "colors = {\n", " \"gaussian\": \"orangered\",\n", " \"train\": \"black\",\n", " \"diffusion\": \"steelblue\"\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "6a026770-961a-4232-a947-68311d30cf89", "metadata": {}, "outputs": [], "source": [ "lmin = 300\n", "lmax = 6000\n", "delta_l = 720\n", "bp_edges = np.arange(lmin, lmax, delta_l)\n", "bandpass_centers = (bp_edges[:-1] + bp_edges[1:]) // 2" ] }, { "cell_type": "code", "execution_count": null, "id": "e5ca5814-a4b1-478d-9b08-919a93478948", "metadata": {}, "outputs": [], "source": [ "def load_moment_sum(model, noise=None):\n", " tag = \"samples_curve\" if model == \"diffusion\" else model\n", " suffix = f\"_sum_{noise}\" if noise else \"_sum\"\n", " fname = PATCHES_DIR / f\"moments_{tag}_2mJy_deltaell_720_200rlz_6x6_lp_joint3{suffix}.npy\"\n", " return load_all_moments(fname, bandpass_centers)\n", "\n", "sum_data_clean = {\n", " model: load_moment_sum(model)\n", " for model in models\n", "}\n", "\n", "sum_data_noisy = {\n", " noise: {\n", " model: load_moment_sum(model, noise)\n", " for model in models\n", " }\n", " for noise in noise_levels\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "fe46a5dd-c712-4569-92fe-857f135048fa", "metadata": {}, "outputs": [], "source": [ "plt.clf()\n", "\n", "tr, tc = 1, 2\n", "plt.figure(figsize=(8., 4.2))\n", "plt.subplots_adjust(wspace=0.1)\n", "\n", "alphaval = 1.\n", "mewval = 0.8\n", "fsval = 14\n", "msval = 8.\n", "capsizeval = 0.\n", "\n", "offset_step = 0.\n", 
"model_offset_step = 50\n", "bpwidth = np.diff(bandpass_centers)[0]\n", "\n", "for idx, (label, moment_key) in enumerate(zip(title_labels, moment_keys)):\n", " ax = plt.subplot(tr, tc, idx + 1)\n", "\n", " for noise_i, noise_level in enumerate(noise_levels):\n", " base_offset = (noise_i - len(noise_levels) // 2) * offset_step\n", "\n", " for model_i, model in enumerate(markers):\n", " color = colors[model]\n", " mean_data = np.array(sum_data_clean[model][moment_key])\n", " noisy_data = np.array(sum_data_noisy[noise_level][model][moment_key])\n", "\n", " mean_vals = mean(mean_data, axis=0)\n", " std_errs = std(noisy_data, axis=0) / np.sqrt(noisy_data.shape[0])\n", " offset = base_offset + (model_i - len(models) // 2) * model_offset_step\n", " shifted_x = bandpass_centers[1:] + offset\n", "\n", " errorbar(shifted_x, mean_vals[1:]* mul_fac_dic[idx], yerr=std_errs[1:]* mul_fac_dic[idx],\n", " fmt=markers[model], capsize=capsizeval, color=color, alpha=alphaval,\n", " label=model_labels[model_i], ms=msval, lw=mewval)\n", "\n", " if idx == 0:\n", " plt.ylabel(labels[0], fontsize=fsval)\n", " #ylim(-8e-4, 0.5e-4)\n", " axhline(0., lw=1., alpha=0.5)\n", " elif idx == 1:\n", " plt.ylabel(labels[1], fontsize=fsval)\n", " legend(fontsize=12, numpoints=1, handletextpad=-0.2)\n", " #ylim(-0.5e-6, 8e-6)\n", " axhline(0., lw=1., alpha=0.5)\n", "\n", " #title(label, fontsize=fsval)\n", " xlim(bandpass_centers[1] - bpwidth/2 - 20, max(bandpass_centers) + bpwidth/2 + 30)\n", " xlabel(r\"Multipole $\\ell_c$\", fontsize=fsval)\n", " grid(True, which='both', ls='-', lw=0.2, alpha=0.2)\n", "\n", " for cntr in range(len(bandpass_centers)):\n", " if cntr % 2 != 0:\n", " x1 = bandpass_centers[cntr] - bpwidth / 2\n", " x2 = bandpass_centers[cntr] + bpwidth / 2\n", " axvspan(x1, x2, color='peru', alpha=0.2)\n", " \n", " label_with_mulfac = r'%s [$10^{%g}$]' %(label, np.log10(mul_fac_dic[idx]))\n", " ax.set_title(label_with_mulfac, fontsize=fsval)\n", "\n", " grid(False, which='both', axis = 
'both')\n", "plt.tight_layout()\n", "savefig(\"/global/homes/k/kp22/ddpm_paper/figures/moments_sum_cib_tsz.pdf\", bbox_inches=\"tight\")\n", "savefig(\"/global/homes/k/kp22/ddpm_paper/figures/moments_sum_cib_tsz.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "338ba49d-e7c7-4129-91aa-a15839397de8", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "81a72f18-e115-41f9-8e9e-74fab85dd93e", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "a01b61a7-3c33-4d87-ab9d-76b6cfe92733", "metadata": {}, "outputs": [], "source": [ "noise_levels = [\"spt3g_noise\", \"s4wide_noise\", \"s4deep_noise\"]\n", "models = [\"gaussian\", \"train\", \"diffusion\"]\n", "\n", "base_path = PATCHES_DIR / \"\"\n", "file_template_clean = base_path + \"moments_{tag}_2mJy_deltaell_720_200rlz_6x6_lp_joint3.npy\"\n", "file_template_noisy = base_path + \"moments_{tag}_2mJy_deltaell_720_200rlz_6x6_lp_joint3_{noise}.npy\"\n", "\n", "means_clean = {\n", " \"train\": load_all_moments(file_template_clean.format(tag=\"train\"), bandpass_centers),\n", " \"diffusion\": load_all_moments(file_template_clean.format(tag=\"samples_curve\"), bandpass_centers),\n", " \"gaussian\": load_all_moments(file_template_clean.format(tag=\"gaussian\"), bandpass_centers),\n", "}\n", "\n", "errors_noisy_all = {\n", " noise: {\n", " \"train\": load_all_moments(file_template_noisy.format(tag=\"train\", noise=noise), bandpass_centers),\n", " \"diffusion\": load_all_moments(file_template_noisy.format(tag=\"samples_curve\", noise=noise), bandpass_centers),\n", " \"gaussian\": load_all_moments(file_template_noisy.format(tag=\"gaussian\", noise=noise), bandpass_centers),\n", " }\n", " for noise in noise_levels\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "6faf801a-90df-44c1-a7a5-5f02c6de44ee", "metadata": {}, "outputs": [], "source": [ "exclude_n = 1\n", 
"fsval = 16\n", "offset_step = 100\n", "\n", "labels = [\n", " r\"$S_3^{aaa}$\", r\"$S_3^{bbb}$\", r\"$S_3^{aab}$\", r\"$S_3^{abb}$\",\n", " r\"$S_4^{aaaa}$\", r\"$S_4^{bbbb}$\", r\"$S_4^{aaab}$\", r\"$S_4^{aabb}$\", r\"$S_4^{abbb}$\"\n", "]\n", "\n", "colors = {\"gaussian\": \"orangered\", \"train\": \"black\", \"diffusion\": \"steelblue\"}\n", "markers = {\"s4deep_noise\": \"o\", \"s4wide_noise\": \"^\", \"spt3g_noise\": \"*\"}\n", "\n", "bpwidth = np.diff(bandpass_centers)[0]\n", "\n", "plt.clf()\n", "fig, axes = subplots(3, 3, figsize=(16, 12), sharex=True)\n", "axes = axes.flatten()\n", "\n", "for idx, label in enumerate(labels):\n", " ax = axes[idx]\n", " moment_key = f\"moment_{3 + idx:02d}\"\n", "\n", " for noise_i, noise_level in enumerate(noise_levels):\n", " offset = (noise_i - len(noise_levels) // 2) * offset_step\n", "\n", " for model in models:\n", " mean_vals = mean(means_clean[model][moment_key], axis=0)\n", " std_errs = std(errors_noisy_all[noise_level][model][moment_key], axis=0) / np.sqrt(len(errors_noisy_all[noise_level][model][moment_key]))\n", "\n", " shifted_x = bandpass_centers[exclude_n:] + offset\n", " label_str = f\"{model.capitalize()}\" if idx == 0 and noise_i == 0 else None\n", "\n", " ax.errorbar(shifted_x, mean_vals[exclude_n:], yerr=std_errs[exclude_n:], \n", " fmt=markers[noise_level], capsize=2, color=colors[model],\n", " alpha=0.5, label=label_str)\n", "\n", " for cntr in range(len(bandpass_centers)):\n", " if cntr % 2 != 0:\n", " x1 = bandpass_centers[cntr] - bpwidth / 2\n", " x2 = bandpass_centers[cntr] + bpwidth / 2\n", " ax.axvspan(x1, x2, color='peru', alpha=0.2)\n", "\n", " ax.set_title(label, fontsize=fsval)\n", " ax.grid(True, which='both', ls='-', lw=0.2, alpha=0.2)\n", " if idx // 3 == 2:\n", " ax.set_xlabel(r\"Multipole $\\ell_c$\", fontsize=fsval)\n", " if idx % 3 == 0:\n", " ax.set_ylabel(\"Statistic value\", fontsize=fsval)\n", "\n", "legend_elements = (\n", " [Line2D([0], [0], color=colors[m], lw=4, 
label=m.capitalize()) for m in models] +\n", " [Line2D([0], [0], marker=markers[n], color='black', linestyle='None', markersize=8, label=n.replace('_', '-')) for n in markers] +\n", " [Line2D([], [], color='none', label=r\"$a$ = CIB,\\ $b$ = tSZ\")]\n", ")\n", "fig.legend(handles=legend_elements, fontsize=14, loc='lower center',\n", " bbox_to_anchor=(0.5, -0.05), ncol=len(legend_elements), frameon=False)\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"figures/moments_joint_cib_tsz.pdf\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/moments_joint_cib_tsz.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "e3a2dfc3-f60e-45ee-be13-657be5b0e05d", "metadata": {}, "outputs": [], "source": [ "exclude_n = 1\n", "fsval = 16\n", "offset_step = 100\n", "\n", "labels = [\n", " #r\"$S_3^{aaa}$\", r\"$S_3^{aab}$\", r\"$S_3^{abb}$\", r\"$S_3^{bbb}$\",\n", " r\"$S_3^{aaa}$\", r\"$S_3^{bbb}$\", r\"$S_3^{aab}$\", r\"$S_3^{abb}$\",\n", "]\n", "\n", "colors = {\"gaussian\": \"orangered\", \"train\": \"black\", \"diffusion\": \"steelblue\"}\n", "markers = {\"s4deep_noise\": \"o\", \"s4wide_noise\": \"^\", \"spt3g_noise\": \"*\"}\n", "expname_dic = {\"s4deep_noise\": r'S4-Ultra deep', \"s4wide_noise\": r\"S4-Wide\", \"spt3g_noise\": r\"SPT-3G\"}\n", "model_labels = {\"gaussian\": \"Gaussian\", \"train\": \"Agora\", \"diffusion\": \"DDPM\"}\n", "\n", "bpwidth = np.diff(bandpass_centers)[0]\n", "\n", "tr, tc = 2, 2\n", "plt.clf()\n", "fig, axes = subplots(tr, tc, figsize=(8, 6.5), sharex=True)\n", "axes = axes.flatten()\n", "plt.subplots_adjust(wspace = 0.02)\n", "\n", "mul_fac_dic = {0: 1e4, 1: 1e4, 2: 1e5, 3: 1e5}\n", "\n", "for idx, label in enumerate(labels):\n", " ax = axes[idx]\n", " moment_key = f\"moment_{3 + idx:02d}\"\n", " print(moment_key)\n", " curr_mul_fac = mul_fac_dic[idx]\n", "\n", " for noise_i, noise_level in enumerate(noise_levels):\n", " offset = (noise_i - len(noise_levels) // 2) * offset_step\n", "\n", " 
for model in models:\n", " mean_vals = mean(means_clean[model][moment_key], axis=0)\n", " std_errs = std(errors_noisy_all[noise_level][model][moment_key],axis=0) / np.sqrt(len(errors_noisy_all[noise_level][model][moment_key]))\n", "\n", " shifted_x = bandpass_centers[exclude_n:] + offset\n", " label_str = f\"{model_labels[model]}\" if idx == 0 and noise_i == 0 else None\n", "\n", " ax.errorbar(shifted_x, mean_vals[exclude_n:] * curr_mul_fac, yerr=std_errs[exclude_n:]* curr_mul_fac, \n", " fmt=markers[noise_level], capsize=2, color=colors[model],\n", " alpha=0.5, \n", " #label=label_str, \n", " ms = 4.)\n", "\n", "\n", " for cntr in range(len(bandpass_centers)):\n", " if cntr % 2 != 0:\n", " x1 = bandpass_centers[cntr] - bpwidth / 2\n", " x2 = bandpass_centers[cntr] + bpwidth / 2\n", " ax.axvspan(x1, x2, color='peru', alpha=0.2)\n", "\n", " ax.axhline(0., lw = 0.5, color = 'black')\n", "\n", " label_with_mulfac = r'%s [$10^{%g}$]' %(label, np.log10(curr_mul_fac))\n", " ax.set_title(label_with_mulfac, fontsize=fsval)\n", " ax.grid(False, which='both', axis = 'both')#ls='-', lw=0.2, alpha=0.2)\n", " if idx >= tc:\n", " ax.set_xlabel(r\"Multipole $\\ell_c$\", fontsize=fsval)\n", " if idx % tc == 0:\n", " #ax.set_ylabel(\"Statistic value\", fontsize=fsval)\n", " ax.set_ylabel(\"Bispectra $S_{3}$ [$\\mu$K$^{3}$]\", fontsize=fsval)\n", "\n", " if idx == 0: #add legend for models\n", " for model in models:\n", " print(model)\n", " label_str = f\"{model_labels[model]}\"\n", " ax.errorbar([], [], yerr=[0], \n", " fmt='o', \n", " capsize=2, \n", " color=colors[model],\n", " alpha=0.5, \n", " label=label_str, \n", " ms = 4.)\n", " ax.legend(loc = 1, fontsize = fsval - 3, numpoints = 1)\n", " ax.text(4000, 1.75, r'$a = {\\rm CIB}$')\n", " ax.text(4000, 1.65, r'$b = {\\rm tSZ}$') \n", " if idx == 1: #add legend for models\n", " for expname in expname_dic:\n", " label_str = r'%s' %(expname_dic[expname])\n", " ax.errorbar([], [], yerr=[0], \n", " capsize=2, \n", " 
color='black',\n", " fmt=markers[expname],\n", " alpha=0.5, \n", " label=label_str, \n", " ms = 4.,\n", " )\n", " leg = ax.legend( fontsize = fsval - 3, title = r'{\\bf Noise level:}', \n", " title_fontsize = fsval - 3, numpoints = 1)\n", " leg._legend_box.align = \"left\"\n", "\n", "plt.tight_layout()\n", "savefig(\"figures/bispectra_joint_cib_tsz.pdf\", bbox_inches=\"tight\")\n", "savefig(\"figures/bispectra_joint_cib_tsz.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "2215f8fe-4d12-4069-9173-4348cef59b33", "metadata": {}, "outputs": [], "source": [ "exclude_n = 1\n", "fsval = 13\n", "offset_step = 100\n", "\n", "subplot_panel_dic = {\n", " #value for each key is (rowno, colno), rowspan, colspan, mulfac. \n", " r\"$S_4^{aaaa}$\": [(0, 0), 3, 1, 1e7], \n", " r\"$S_4^{bbbb}$\": [(3, 0), 3, 1, 1e6], \n", " r\"$S_4^{aaab}$\": [(0, 1), 2, 1, 1e8], \n", " r\"$S_4^{aabb}$\": [(2, 1), 2, 1, 1e8], \n", " r\"$S_4^{abbb}$\": [(4, 1), 2, 1, 1e7],\n", "}\n", "\n", "colors = {\"gaussian\": \"orangered\", \"train\": \"black\", \"diffusion\": \"steelblue\"}\n", "markers = {\"s4deep_noise\": \"o\", \"s4wide_noise\": \"^\", \"spt3g_noise\": \"*\"}\n", "expname_dic = {\"s4deep_noise\": r'S4-Ultra deep', \"s4wide_noise\": r\"S4-Wide\", \"spt3g_noise\": r\"SPT-3G\"}\n", "\n", "bpwidth = np.diff(bandpass_centers)[0]\n", "\n", "\n", "plt.clf()\n", "tr, tc = 6, 2\n", "fig = plt.figure(figsize=(9.2, 9.2))\n", "plt.subplots_adjust(wspace = 0.15, hspace = 0.4)\n", "\n", "for idx,label in enumerate(subplot_panel_dic):\n", " curr_sbpl, curr_rowspan, curr_colspan, curr_mul_fac = subplot_panel_dic[label]\n", " curr_row, curr_col = curr_sbpl\n", " ax = subplot2grid( (tr, tc), curr_sbpl, colspan = curr_colspan, rowspan = curr_rowspan)\n", "\n", " moment_key = f\"moment_{7 + idx:02d}\"\n", " print(moment_key)\n", " for noise_i, noise_level in enumerate(noise_levels):\n", " offset = (noise_i - len(noise_levels) // 2) * offset_step\n", 
"\n", " for model in models:\n", " mean_vals = mean(means_clean[model][moment_key], axis=0)\n", " std_errs = std(errors_noisy_all[noise_level][model][moment_key], axis=0) / np.sqrt(len(errors_noisy_all[noise_level][model][moment_key]))\n", "\n", " shifted_x = bandpass_centers[exclude_n:] + offset\n", " label_str = f\"{model.capitalize()}\" if idx == 0 and noise_i == 0 else None\n", "\n", " ax.errorbar(shifted_x, mean_vals[exclude_n:] * curr_mul_fac, yerr=std_errs[exclude_n:]* curr_mul_fac, \n", " fmt=markers[noise_level], capsize=2, color=colors[model],\n", " alpha=0.5, \n", " #label=label_str, \n", " ms = 4.)\n", "\n", " for cntr in range(len(bandpass_centers)):\n", " if cntr % 2 != 0:\n", " x1 = bandpass_centers[cntr] - bpwidth / 2\n", " x2 = bandpass_centers[cntr] + bpwidth / 2\n", " ax.axvspan(x1, x2, color='peru', alpha=0.2)\n", "\n", " ax.axhline(0., lw = 0.5, color = 'black')\n", "\n", " label_with_mulfac = r'%s [$10^{%g}$]' %(label, np.log10(curr_mul_fac))\n", " ax.set_title(label_with_mulfac, fontsize=fsval)\n", " ax.grid(False, which='both', axis = 'both')#ls='-', lw=0.2, alpha=0.2)\n", " if curr_row in [3, 4]:\n", " ax.set_xlabel(r\"Multipole $\\ell_c$\", fontsize=fsval)\n", " else:\n", " setp(ax.get_xticklabels(), visible=False); \n", " if curr_col == 0:\n", " ax.set_ylabel(\"Trispectra $S_{4}$ [$\\mu$K$^{4}$]\", fontsize=fsval)\n", "\n", " for label in ax.get_xticklabels(): label.set_fontsize(fsval-2)\n", " for label in ax.get_yticklabels(): label.set_fontsize(fsval-2)\n", "\n", " if curr_sbpl in [(0,0)]: #add legend for models\n", " for model in models:\n", " print(model)\n", " label_str = f\"{model_labels[model]}\"\n", " ax.errorbar([], [], yerr=[0], \n", " fmt='o', \n", " capsize=2, \n", " color=colors[model],\n", " alpha=0.5, \n", " label=label_str, \n", " ms = 4.)\n", " ax.legend( fontsize = fsval - 2, numpoints = 1)\n", " ax.text(4000,2.65, r'$a = {\\rm CIB}$')\n", " ax.text(4000,2.5, r'$b = {\\rm tSZ}$')\n", " \n", " if curr_sbpl in [(3,0)]: 
#add legend for experiments\n", " for expname in expname_dic:\n", " label_str = r'%s' %(expname_dic[expname])\n", " ax.errorbar([], [], yerr=[0], \n", " capsize=2, \n", " color='black',\n", " fmt=markers[expname],\n", " alpha=0.5, \n", " label=label_str, \n", " ms = 4.,\n", " )\n", " leg = ax.legend( fontsize = fsval - 2, title = r'{\\bf Noise level:}', \n", " title_fontsize = fsval - 2, numpoints = 1)\n", " leg._legend_box.align = \"left\"\n", "\n", "savefig(\"figures/trispectra_joint_cib_tsz.pdf\", bbox_inches=\"tight\")\n", "savefig(\"figures/trispectra_joint_cib_tsz.png\", bbox_inches=\"tight\")\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "4904ba1e-5b32-4c20-8bec-426c6e7e6687", "metadata": {}, "source": "## Minkowski Functionals" }, { "cell_type": "code", "execution_count": null, "id": "381ddd76-9507-43a9-99df-ce7723b31d19", "metadata": {}, "outputs": [], "source": [ "loaded = np.load(PATCHES_DIR / \"minkowski_results.npz\", allow_pickle=True)\n", "results = {k: loaded[k].tolist() for k in loaded}\n", "thresholds = np.linspace(0., 1, 50) " ] }, { "cell_type": "code", "execution_count": null, "id": "597dab83-2599-4367-8af9-73d99db6369d", "metadata": {}, "outputs": [], "source": [ "fsval = 16 \n", "fig, ax = plt.subplots(3, 2, figsize=(10, 8))\n", "\n", "titles = [\"M0 (Area)\", \"M1 (Perimeter)\", \"M2 (Genus)\"]\n", "column_titles = [\"CIB\", \"tSZ\"]\n", "\n", "for i, stat in enumerate(['M0', 'M1', 'M2']):\n", " for j, label in enumerate(['cib', 'tsz']):\n", " a = ax[i, j]\n", " train_data = results[f'{stat}_train_{label}']\n", " samples_data = results[f'{stat}_samples_{label}']\n", " gaussian_data = results[f'{stat}_gaussian_{label}']\n", "\n", " train_means = [x[0] for x in train_data]\n", " train_stds = [x[1] for x in train_data]\n", " samples_means = [x[0] for x in samples_data]\n", " samples_stds = [x[1] for x in samples_data]\n", " gaussian_means = [x[0] for x in gaussian_data]\n", " gaussian_stds = [x[1] for x in gaussian_data]\n", "\n", " 
#if label=='tsz':\n", " #a.set_xscale('log')\n", " #if label == 'tsz':\n", " # a.set_xlim(0.8, 1)\n", "\n", " if stat == 'M0':\n", " a.set_yscale('log')\n", "\n", " a.plot(thresholds, train_means, label='Agora', color='black', lw=3, alpha=0.7)\n", " a.fill_between(thresholds, np.array(train_means) - np.array(train_stds),\n", " np.array(train_means) + np.array(train_stds), alpha=0.3, color='black')\n", " a.plot(thresholds, samples_means, label='DDPM', color='steelblue', lw=1)\n", " a.fill_between(thresholds, np.array(samples_means) - np.array(samples_stds),\n", " np.array(samples_means) + np.array(samples_stds), alpha=0.3, color='steelblue')\n", " a.plot(thresholds, gaussian_means, label='Gaussian', color=WONG[6], lw=1)\n", " a.fill_between(thresholds, np.array(gaussian_means) - np.array(gaussian_stds),\n", " np.array(gaussian_means) + np.array(gaussian_stds), alpha=0.3, color=WONG[6])\n", "\n", " if i == 2:\n", " a.set_xlabel(r\"Threshold $\\nu$\", fontsize=fsval+2)\n", "\n", " if j == 0:\n", " a.set_ylabel(titles[i], fontsize=fsval+2)\n", "\n", " if i == 0:\n", " a.set_title(column_titles[j], fontsize=fsval + 2)\n", "\n", " if i == 0 and j == 1:\n", " a.legend(fontsize=fsval)\n", "\n", "plt.tight_layout()\n", "plt.savefig(\"figures/minkowski.pdf\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/minkowski.png\", bbox_inches=\"tight\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "03a5da9a-d2be-4127-9dd3-89184ff8b3d5", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "6cacc65c-e82d-4a13-876c-34e935946562", "metadata": {}, "outputs": [], "source": [ "fig, ax = plt.subplots(3, 2, figsize=(14, 12), sharex=True)\n", "titles = [\"M0 (Area)\", \"M1 (Perimeter)\", \"M2 (Genus)\"]\n", "\n", "for i, stat in enumerate(['M0', 'M1', 'M2']):\n", " for j, label in enumerate(['cib', 'tsz']):\n", " a = ax[i, j]\n", " train_data = results[f'{stat}_train_{label}']\n", " samples_data = 
results[f'{stat}_samples_{label}']\n", " gaussian_data = results[f'{stat}_gaussian_{label}']\n", "\n", " train_means = [x[0] for x in train_data]\n", " train_stds = [x[1] for x in train_data]\n", " samples_means = [x[0] for x in samples_data]\n", " samples_stds = [x[1] for x in samples_data]\n", " gaussian_means = [x[0] for x in gaussian_data]\n", " gaussian_stds = [x[1] for x in gaussian_data]\n", "\n", " if stat == 'M0':\n", " a.set_yscale('log')\n", " a.plot(thresholds, train_means, label='Agora', color='black', lw=2)\n", " a.fill_between(thresholds, np.array(train_means) - np.array(train_stds),\n", " np.array(train_means) + np.array(train_stds), alpha=0.3, color='black')\n", " a.plot(thresholds, samples_means, label='DDPM', color='steelblue', lw=2)\n", " a.fill_between(thresholds, np.array(samples_means) - np.array(samples_stds),\n", " np.array(samples_means) + np.array(samples_stds), alpha=0.3, color='steelblue')\n", " a.plot(thresholds, gaussian_means, label='Gaussian', color=WONG[6], lw=2)\n", " a.fill_between(thresholds, np.array(gaussian_means) - np.array(gaussian_stds),\n", " np.array(gaussian_means) + np.array(gaussian_stds), alpha=0.3, color=WONG[6])\n", "\n", " if i == 2:\n", " a.set_xlabel(r\"Threshold $\\nu$\",fontsize=fsval)\n", " #else:\n", " #a.set_xticklabels([])\n", "\n", " if j == 0:\n", " a.set_ylabel(titles[i],fontsize=fsval)\n", " #else:\n", " #a.set_yticklabels([])\n", "\n", " if i == 0 and j == 1:\n", " a.legend(fontsize=fsval)\n", "\n", "#plt.subplots_adjust(wspace=0, hspace=0)\n", "plt.savefig(\"figures/minkowski.pdf\", bbox_inches=\"tight\")\n", "plt.savefig(\"figures/minkowski.png\", bbox_inches=\"tight\")\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "4ab6722d-ffef-49a6-936b-0bff8a51b52d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "denoising_diffusion_pytorch", "language": "python", "name": "denoising_diffusion_pytorch" }, "language_info": { 
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 5 }