{ "cells": [ { "cell_type": "markdown", "id": "b4265954", "metadata": {}, "source": [ "# APC MNIST Training Example\n", "\n", "This notebook is a compact, notebook-first APC training walkthrough for MNIST.\n", "\n", "Design choices are intentionally fixed to keep code short:\n", "\n", "- Binomial data distribution\n", "- Conv-PC encoder\n", "- NN decoder\n", "- AdamW optimizer\n", "- MultiStepLR scheduler (milestones at 66% and 90%)\n", "- Training progress printed only at 10%, 20%, ..., 100%\n", "- Inline metrics and reconstructions only (no checkpoint/JSON/image file outputs)\n" ] }, { "cell_type": "markdown", "id": "9a85ca9a", "metadata": {}, "source": [ "## Imports\n", "\n", "We keep imports explicit so you can quickly see which pieces handle modeling, training, and visualization.\n", "\n", "- Standard library: lightweight config and serialization helpers\n", "- PyTorch + Lightning Fabric: tensor ops and runtime orchestration\n", "- SPFlow APC modules: encoder/decoder/model/loss tooling\n", "- Torchvision + plotting stack: data loading and inline analysis\n" ] }, { "cell_type": "code", "id": "3b3ba0b6", "metadata": { "execution": { "iopub.execute_input": "2026-03-04T14:46:20.912698Z", "iopub.status.busy": "2026-03-04T14:46:20.912460Z", "iopub.status.idle": "2026-03-04T14:46:22.455980Z", "shell.execute_reply": "2026-03-04T14:46:22.455479Z" }, "ExecuteTime": { "end_time": "2026-03-04T15:04:07.772178Z", "start_time": "2026-03-04T15:04:07.747505Z" } }, "source": [ "from __future__ import annotations\n", "\n", "# Standard-library helpers for config handling and lightweight reporting.\n", "import json\n", "import math\n", "from dataclasses import asdict, dataclass\n", "from pathlib import Path\n", "\n", "# Core scientific stack used throughout the notebook.\n", "import matplotlib.pyplot as plt\n", "import pandas as pd\n", "import torch\n", "from torch import nn\n", "from torch.optim import AdamW, Optimizer\n", "from torch.optim.lr_scheduler import MultiStepLR\n", "from torch.utils.data import DataLoader, random_split\n", "\n", "# APC/SPFlow model components.\n", "from spflow.modules.leaves import Binomial, Normal\n", "from spflow.modules.leaves.leaf import LeafModule\n", "from spflow.zoo.apc.config import ApcConfig, ApcLossWeights\n", "from spflow.zoo.apc.decoders import NeuralDecoder2D\n", "from spflow.zoo.apc.encoders.convpc_joint_encoder import ConvPcJointEncoder\n", "from spflow.zoo.apc.model import AutoencodingPC\n", "from spflow.zoo.apc.train import evaluate_apc\n", "\n", "# Runtime orchestration and dataset transforms/utilities.\n", "import lightning as L\n", "from torchvision import datasets, transforms\n", "from torchvision.utils import make_grid\n" ], "outputs": [], "execution_count": 35 }, { "cell_type": "markdown", "id": "e5ccebb3", "metadata": {}, "source": [ "## Defaults\n", "\n", "A single dataclass holds all fixed tutorial settings so runs stay reproducible and easy to modify.\n", "\n", "These values prioritize a fast, stable guide run over full hyperparameter tuning.\n" ] }, { "cell_type": "code", "id": "a4586299", "metadata": { "execution": { "iopub.execute_input": "2026-03-04T14:46:22.457243Z", "iopub.status.busy": "2026-03-04T14:46:22.457135Z", "iopub.status.idle": "2026-03-04T14:46:22.461588Z", "shell.execute_reply": "2026-03-04T14:46:22.461193Z" }, "ExecuteTime": { "end_time": "2026-03-04T15:04:07.816927Z", "start_time": "2026-03-04T15:04:07.787446Z" } }, "source": [ "@dataclass\n", "class Config:\n", " \"\"\"Fixed configuration for this simplified APC MNIST tutorial notebook.\"\"\"\n", "\n", " # Reproducibility and dataset location.\n", " seed: int = 0\n", " data_dir: Path = Path(\"./data/mnist\")\n", " download: bool = True\n", "\n", " # Input representation used in this notebook.\n", " dist_data: str = \"binomial\"\n", " image_size: int = 32\n", " n_bits: int = 8\n", "\n", " # Runtime execution settings.\n", " device: str = \"auto\"\n", " precision: str = \"bf16-mixed\"\n", " num_workers: int = 0\n", "\n", " # Training budget and dataset caps for notebook speed.\n", " iters: int = 1000\n", " batch_size: int = 128\n", " val_size: int = 10_000\n", " max_train_samples: int | None = 60_000\n", " max_val_samples: int | None = 2_000\n", " max_test_samples: int | None = 2_000\n", "\n", " # Optimizer and scheduler hyperparameters.\n", " lr_encoder: float = 1e-1\n", " lr_decoder: float = 1e-3\n", " weight_decay: float = 0.0\n", " lr_gamma: float = 0.1\n", "\n", " # Encoder/decoder architecture controls.\n", " latent_dim: int = 64\n", " conv_channels: int = 64\n", " conv_depth: int = 3\n", " conv_latent_depth: int = 0\n", " conv_use_sum_conv: bool = False\n", " num_repetitions: int = 1\n", "\n", " nn_hidden: int = 64\n", " nn_res_hidden: int = 16\n", " nn_res_layers: int = 2\n", " nn_scales: int = 2\n", " nn_bn: bool = True\n", " nn_out_activation: str = \"identity\"\n", "\n", " # APC objective weighting and sampling setup.\n", " rec_loss: str = \"mse\"\n", " sample_tau: float = 1.0\n", " w_rec: float = 1.0\n", " w_kld: float = 1.0\n", " w_nll: float = 1.0\n", "\n", " grad_clip_norm: float | None = None\n", " warmup_pct: float = 2.0\n", " num_vis: int = 8\n", "\n", "\n", "cfg = Config()\n", "cfg\n" ], "outputs": [ { "data": { "text/plain": [ "Config(seed=0, data_dir=PosixPath('data/mnist'), download=True, dist_data='binomial', image_size=32, n_bits=8, device='auto', precision='bf16-mixed', num_workers=0, iters=1000, batch_size=128, val_size=10000, max_train_samples=60000, max_val_samples=2000, max_test_samples=2000, lr_encoder=0.1, lr_decoder=0.001, weight_decay=0.0, lr_gamma=0.1, latent_dim=64, conv_channels=64, conv_depth=3, conv_latent_depth=0, conv_use_sum_conv=False, num_repetitions=1, nn_hidden=64, nn_res_hidden=16, nn_res_layers=2, nn_scales=2, nn_bn=True, nn_out_activation='identity', rec_loss='mse', sample_tau=1.0, w_rec=1.0, w_kld=1.0, w_nll=1.0, grad_clip_norm=None, warmup_pct=2.0, num_vis=8)" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "execution_count": 36 }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Configuration Notes\n", "\n", "When you start experimenting, change one group of settings at a time:\n", "\n", "- `iters` / `batch_size`: runtime and optimization smoothness\n", "- `latent_dim` + conv settings: encoder capacity\n", "- `w_rec`, `w_kld`, `w_nll`: reconstruction vs regularization trade-off\n", "\n", "Small, isolated changes make the loss table and reconstructions easier to interpret.\n" ], "id": "8ca06aaa" }, { "cell_type": "markdown", "id": "e75595f1", "metadata": {}, "source": [ "## Minimal Helpers (Data, Model, Train)\n", "\n", "This helper block keeps the rest of the notebook compact by grouping the end-to-end APC workflow:\n", "\n", "1. Build deterministic dataset splits and loaders\n", "2. Construct the Conv-PC encoder + neural decoder APC model\n", "3. Configure optimizer and learning-rate schedule\n", "4. Train for a fixed iteration budget with periodic validation\n", "5. Build a reconstruction grid for qualitative inspection\n" ] }, { "cell_type": "code", "id": "fd8e841b", "metadata": { "execution": { "iopub.execute_input": "2026-03-04T14:46:22.462662Z", "iopub.status.busy": "2026-03-04T14:46:22.462585Z", "iopub.status.idle": "2026-03-04T14:46:22.472605Z", "shell.execute_reply": "2026-03-04T14:46:22.472182Z" }, "ExecuteTime": { "end_time": "2026-03-04T15:04:07.833334Z", "start_time": "2026-03-04T15:04:07.818042Z" } }, "source": [ "def seed_everything(seed: int) -> None:\n", " \"\"\"Seed torch RNGs for reproducible notebook runs.\"\"\"\n", " torch.manual_seed(seed)\n", " if torch.cuda.is_available():\n", " torch.cuda.manual_seed_all(seed)\n", "\n", "\n", "def resolve_device(name: str) -> torch.device:\n", " \"\"\"Resolve the runtime device from `auto`, `cpu`, or `cuda`.\"\"\"\n", " if name == \"auto\":\n", " return torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", " return torch.device(name)\n", "\n", "\n", "def build_fabric(*, device: torch.device, precision: str) -> L.Fabric:\n", " \"\"\"Create a single-device Lightning Fabric runtime.\"\"\"\n", " accelerator = \"cuda\" if device.type == \"cuda\" else \"cpu\"\n", " return L.Fabric(accelerator=accelerator, devices=1, precision=precision)\n", "\n", "\n", "class QuantizeToNBits:\n", " \"\"\"Torchvision transform: map pixels to integer counts in [0, 2^n_bits-1].\"\"\"\n", "\n", " def __init__(self, n_bits: int) -> None:\n", " self.n_bits = int(n_bits)\n", "\n", " def __call__(self, x: torch.Tensor) -> torch.Tensor:\n", " max_value = float(2**self.n_bits - 1)\n", " return torch.floor(x * max_value)\n", "\n", "\n", "def _cap_subset(subset: torch.utils.data.Subset, max_samples: int | None) -> torch.utils.data.Subset:\n", " \"\"\"Optionally cap a subset length for quick notebook iterations.\"\"\"\n", " if max_samples is None or max_samples >= len(subset):\n", " return subset\n", " return torch.utils.data.Subset(subset.dataset, subset.indices[:max_samples])\n", "\n", "\n", "def build_loaders(cfg: Config) -> tuple[DataLoader, DataLoader, DataLoader]:\n", " \"\"\"Build train/val/test dataloaders for MNIST with fixed binomial preprocessing.\"\"\"\n", " # Keep preprocessing fixed so architecture/loss changes are easier to compare.\n", " if cfg.dist_data != \"binomial\":\n", " raise ValueError(f\"This notebook only supports dist_data='binomial', got {cfg.dist_data!r}\")\n", "\n", " tfm = transforms.Compose(\n", " [\n", " transforms.Resize((cfg.image_size, cfg.image_size)),\n", " transforms.ToTensor(),\n", " QuantizeToNBits(cfg.n_bits),\n", " ]\n", " )\n", "\n", " train_full = datasets.MNIST(root=str(cfg.data_dir), train=True, transform=tfm, download=cfg.download)\n", " test_full = datasets.MNIST(root=str(cfg.data_dir), train=False, transform=tfm, download=cfg.download)\n", "\n", " if cfg.val_size <= 0 or cfg.val_size >= len(train_full):\n", " raise ValueError(f\"val_size must be in [1, {len(train_full)-1}], got {cfg.val_size}\")\n", "\n", " # Use a seeded generator so train/validation split is deterministic.\n", " gen = torch.Generator().manual_seed(cfg.seed)\n", " train_subset, val_subset = random_split(\n", " train_full, [len(train_full) - cfg.val_size, cfg.val_size], generator=gen\n", " )\n", "\n", " train_subset = _cap_subset(train_subset, cfg.max_train_samples)\n", " val_subset = _cap_subset(val_subset, cfg.max_val_samples)\n", " test_subset = _cap_subset(torch.utils.data.Subset(test_full, list(range(len(test_full)))), cfg.max_test_samples)\n", "\n", " # Pin memory only when CUDA is used to speed up host-to-device copies.\n", " pin = torch.cuda.is_available() and resolve_device(cfg.device).type == \"cuda\"\n", " train_loader = DataLoader(train_subset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers, pin_memory=pin)\n", " val_loader = DataLoader(val_subset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=pin)\n", " test_loader = DataLoader(test_subset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers, pin_memory=pin)\n", " return train_loader, val_loader, test_loader\n", "\n", "\n", "def make_x_leaf_factory(n_bits: int):\n", " \"\"\"Create the observed-data leaf factory (Binomial).\"\"\"\n", " total_count_tensor = torch.tensor(float(2**n_bits - 1))\n", "\n", " def _factory(scope_indices: list[int], out_channels: int, num_repetitions: int) -> LeafModule:\n", " shape = (len(scope_indices), out_channels, num_repetitions)\n", " probs = 0.5 + (torch.rand(shape) - 0.5) * 0.2\n", " return Binomial(\n", " scope=scope_indices,\n", " out_channels=out_channels,\n", " num_repetitions=num_repetitions,\n", " total_count=total_count_tensor,\n", " probs=probs,\n", " )\n", "\n", " return _factory\n", "\n", "\n", "def make_z_leaf_factory():\n", " \"\"\"Create the latent leaf factory (Normal).\"\"\"\n", "\n", " def _factory(scope_indices: list[int], out_channels: int, num_repetitions: int) -> LeafModule:\n", " shape = (len(scope_indices), out_channels, num_repetitions)\n", " loc = torch.randn(shape)\n", " logvar = torch.randn(shape)\n", " scale = torch.exp(0.5 * logvar)\n", " return Normal(\n", " scope=scope_indices,\n", " out_channels=out_channels,\n", " num_repetitions=num_repetitions,\n", " loc=loc,\n", " scale=scale,\n", " )\n", "\n", " return _factory\n", "\n", "\n", "def build_model(cfg: Config) -> AutoencodingPC:\n", " \"\"\"Build a fixed Conv-PC encoder with an NN decoder.\"\"\"\n", " encoder = ConvPcJointEncoder(\n", " input_height=cfg.image_size,\n", " input_width=cfg.image_size,\n", " input_channels=1,\n", " latent_dim=cfg.latent_dim,\n", " channels=cfg.conv_channels,\n", " depth=cfg.conv_depth,\n", " kernel_size=2,\n", " num_repetitions=cfg.num_repetitions,\n", " use_sum_conv=cfg.conv_use_sum_conv,\n", " latent_depth=cfg.conv_latent_depth,\n", " architecture=\"reference\",\n", " perm_latents=False,\n", " x_leaf_factory=make_x_leaf_factory(cfg.n_bits),\n", " z_leaf_factory=make_z_leaf_factory(),\n", " )\n", "\n", " decoder = NeuralDecoder2D(\n", " latent_dim=cfg.latent_dim,\n", " output_shape=(1, cfg.image_size, cfg.image_size),\n", " num_hidden=cfg.nn_hidden,\n", " num_res_hidden=cfg.nn_res_hidden,\n", " num_res_layers=cfg.nn_res_layers,\n", " num_scales=cfg.nn_scales,\n", " bn=cfg.nn_bn,\n", " out_activation=cfg.nn_out_activation,\n", " )\n", "\n", " apc_config = ApcConfig(\n", " latent_dim=cfg.latent_dim,\n", " rec_loss=cfg.rec_loss,\n", " sample_tau=cfg.sample_tau,\n", " loss_weights=ApcLossWeights(rec=cfg.w_rec, kld=cfg.w_kld, nll=cfg.w_nll),\n", " )\n", " return AutoencodingPC(encoder=encoder, decoder=decoder, config=apc_config)\n", "\n", "\n", "def build_optimizer(cfg: Config, model: AutoencodingPC) -> Optimizer:\n", " \"\"\"Build AdamW with separate encoder/decoder learning rates.\"\"\"\n", " return AdamW(\n", " [\n", " {\"params\": model.encoder.parameters(), \"lr\": cfg.lr_encoder},\n", " {\"params\": model.decoder.parameters(), \"lr\": cfg.lr_decoder},\n", " ],\n", " weight_decay=cfg.weight_decay,\n", " )\n", "\n", "\n", "def build_scheduler(cfg: Config, optimizer: Optimizer) -> MultiStepLR:\n", " \"\"\"Build fixed MultiStepLR with milestones at 66% and 90% of training.\"\"\"\n", " milestones = sorted({max(1, int(0.66 * cfg.iters)), max(1, int(0.9 * cfg.iters))})\n", " return MultiStepLR(optimizer, milestones=milestones, gamma=cfg.lr_gamma)\n", "\n", "\n", "def _extract_x(batch: torch.Tensor | tuple | list) -> torch.Tensor:\n", " \"\"\"Extract the image tensor from `(x, y)` dataset batches.\"\"\"\n", " if isinstance(batch, torch.Tensor):\n", " return batch\n", " if isinstance(batch, (tuple, list)) and len(batch) > 0 and isinstance(batch[0], torch.Tensor):\n", " return batch[0]\n", " raise TypeError(f\"Unsupported batch type: {type(batch)}\")\n", "\n", "\n", "def collect_vis_batch(loader: DataLoader, device: torch.device, num_vis: int) -> torch.Tensor:\n", " \"\"\"Collect a small fixed test batch used for reconstruction visualization.\"\"\"\n", " chunks: list[torch.Tensor] = []\n", " n = 0\n", " for batch in loader:\n", " x = _extract_x(batch).to(device)\n", " take = min(num_vis - n, x.shape[0])\n", " chunks.append(x[:take])\n", " n += take\n", " if n >= num_vis:\n", " break\n", " if not chunks:\n", " raise RuntimeError(\"Could not collect visualization samples\")\n", " return torch.cat(chunks, dim=0)\n", "\n", "\n", "def _to_image_batch(x: torch.Tensor, image_size: int) -> torch.Tensor:\n", " \"\"\"Ensure tensors are in image shape `(B, 1, H, W)`.\"\"\"\n", " if x.dim() == 4 and x.shape[1:] == (1, image_size, image_size):\n", " return x\n", " if x.dim() == 2 and x.shape[1] == image_size * image_size:\n", " return x.view(-1, 1, image_size, image_size)\n", " raise ValueError(f\"Unexpected shape {tuple(x.shape)}\")\n", "\n", "\n", "def build_recon_grid(model: AutoencodingPC, x_batch: torch.Tensor, cfg: Config) -> torch.Tensor:\n", " \"\"\"Return an inline visualization grid (top row data, bottom row reconstruction).\"\"\"\n", " model.eval()\n", " with torch.no_grad():\n", " x_rec = model.reconstruct(x_batch)\n", "\n", " denom = float(2**cfg.n_bits - 1)\n", " x_img = _to_image_batch(x_batch.detach().cpu(), cfg.image_size).float().div(denom).clamp(0.0, 1.0)\n", " x_rec_img = _to_image_batch(x_rec.detach().cpu(), cfg.image_size).float().div(denom).clamp(0.0, 1.0)\n", " return make_grid(torch.cat([x_img, x_rec_img], dim=0), nrow=x_img.shape[0], padding=1, pad_value=1.0)\n", "\n", "\n", "def _progress_map(iters: int) -> dict[int, int]:\n", " \"\"\"Map training steps to progress percentages: 10, 20, ..., 100.\"\"\"\n", " marks: dict[int, int] = {}\n", " for p in range(1, 11):\n", " step = max(1, int(round(iters * p / 10)))\n", " marks[step] = p * 10\n", " return marks\n", "\n", "\n", "def train_apc_iters(\n", " *,\n", " fabric: L.Fabric,\n", " model: AutoencodingPC,\n", " train_loader: DataLoader,\n", " val_loader: DataLoader,\n", " optimizer: Optimizer,\n", " scheduler: MultiStepLR,\n", " cfg: Config,\n", ") -> list[dict[str, float]]:\n", " \"\"\"Train for a fixed iteration budget and log exactly 10 progress lines.\"\"\"\n", " history: list[dict[str, float]] = []\n", " train_iter = iter(train_loader)\n", "\n", " base_lrs = [float(group[\"lr\"]) for group in optimizer.param_groups]\n", " # Warmup duration is expressed as a fraction of total iterations.\n", " warmup_steps = int(cfg.iters * cfg.warmup_pct / 100.0)\n", " progress_marks = _progress_map(cfg.iters)\n", "\n", " # Track running means between progress checkpoints.\n", " window_totals = {\"rec\": 0.0, \"kld\": 0.0, \"nll\": 0.0, \"total\": 0.0}\n", " window_steps = 0\n", "\n", " for step in range(1, cfg.iters + 1):\n", " try:\n", " batch = next(train_iter)\n", " except StopIteration:\n", " train_iter = iter(train_loader)\n", " batch = next(train_iter)\n", "\n", " x = _extract_x(batch).to(fabric.device)\n", "\n", " model.train()\n", " optimizer.zero_grad(set_to_none=True)\n", " losses = model.loss_components(x)\n", " fabric.backward(losses[\"total\"])\n", "\n", " if cfg.grad_clip_norm is not None:\n", " fabric.clip_gradients(model, optimizer, max_norm=cfg.grad_clip_norm)\n", "\n", " optimizer.step()\n", " scheduler.step()\n", "\n", " # Keep the script-style warmup behavior in compact form.\n", " if warmup_steps > 0 and step <= warmup_steps:\n", " factor = math.exp(-5.0 * (1.0 - (step / warmup_steps)) ** 2)\n", " for group, base_lr in zip(optimizer.param_groups, base_lrs):\n", " group[\"lr\"] = base_lr * factor\n", "\n", " for k in window_totals:\n", " window_totals[k] += float(losses[k].item())\n", " window_steps += 1\n", "\n", " if step not in progress_marks:\n", " continue\n", "\n", " # Run validation only at progress checkpoints to keep notebook runtime low.\n", " val = evaluate_apc(model, val_loader, batch_size=cfg.batch_size)\n", " row = {\n", " \"iter\": float(step),\n", " \"progress_pct\": float(progress_marks[step]),\n", " \"train_total\": window_totals[\"total\"] / window_steps,\n", " \"train_rec\": window_totals[\"rec\"] / window_steps,\n", " \"train_kld\": window_totals[\"kld\"] / window_steps,\n", " \"train_nll\": window_totals[\"nll\"] / window_steps,\n", " \"val_total\": float(val[\"total\"]),\n", " \"val_rec\": float(val[\"rec\"]),\n", " \"val_kld\": float(val[\"kld\"]),\n", " \"val_nll\": float(val[\"nll\"]),\n", " }\n", " history.append(row)\n", "\n", " print(\n", " f\"[APC] {int(row['progress_pct']):3d}% ({step}/{cfg.iters}) \"\n", " f\"train_total={row['train_total']:.4f} val_total={row['val_total']:.4f}\"\n", " )\n", "\n", " window_totals = {\"rec\": 0.0, \"kld\": 0.0, \"nll\": 0.0, \"total\": 0.0}\n", " window_steps = 0\n", "\n", " return history\n" ], "outputs": [], "execution_count": 37 }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Before Running Training\n", "\n", "The run cell follows this order:\n", "\n", "1. Set seed/device/Fabric\n", "2. Build loaders, model, optimizer, scheduler\n", "3. Move components into Fabric-managed runtime\n", "4. Train, evaluate on test set, and prepare plotting payloads\n" ], "id": "49bb8dc1" }, { "cell_type": "markdown", "id": "140dbcf9", "metadata": {}, "source": [ "## Training\n", "\n", "The following wires all helper functions into one runnable pipeline and prints compact progress logs at 10% checkpoints.\n" ] }, { "cell_type": "code", "id": "aa9d03fd", "metadata": { "execution": { "iopub.execute_input": "2026-03-04T14:46:22.473716Z", "iopub.status.busy": "2026-03-04T14:46:22.473634Z", "iopub.status.idle": "2026-03-04T14:51:47.439238Z", "shell.execute_reply": "2026-03-04T14:51:47.438515Z" }, "ExecuteTime": { "end_time": "2026-03-04T15:11:34.130517Z", "start_time": "2026-03-04T15:04:07.833800Z" } }, "source": [ "# Seed first so data splits and initial parameters remain reproducible.\n", "seed_everything(cfg.seed)\n", "\n", "# Resolve backend and initialize a single-device Fabric runtime.\n", "device = resolve_device(cfg.device)\n", "fabric = build_fabric(device=device, precision=cfg.precision)\n", "\n", "# Build data pipeline and APC model from the fixed config.\n", "train_loader, val_loader, test_loader = build_loaders(cfg)\n", "model = build_model(cfg)\n", "optimizer = build_optimizer(cfg, model)\n", "scheduler = build_scheduler(cfg, optimizer)\n", "\n", "# Let Fabric place loaders/model/optimizer on the selected runtime device.\n", "train_loader, val_loader, test_loader = fabric.setup_dataloaders(train_loader, val_loader, test_loader)\n", "# Keep a fixed mini-batch for consistent reconstruction snapshots.\n", "vis_batch = collect_vis_batch(test_loader, device=device, num_vis=cfg.num_vis)\n", "model, optimizer = fabric.setup(model, optimizer)\n", "\n", "print(f\"[APC] Device: {device}\")\n", "print(\n", " f\"[APC] Dataset sizes: train={len(train_loader.dataset)}, val={len(val_loader.dataset)}, test={len(test_loader.dataset)}\"\n", ")\n", "print(\n", " f\"[APC] Fixed setup: Conv-PC + NN decoder, AdamW, MultiStepLR(66/90), \"\n", " f\"batch_size={cfg.batch_size}, iters={cfg.iters}\"\n", ")\n", "\n", "# Run the fixed-iteration training loop.\n", "history = train_apc_iters(\n", " fabric=fabric,\n", " model=model,\n", " train_loader=train_loader,\n", " val_loader=val_loader,\n", " optimizer=optimizer,\n", " scheduler=scheduler,\n", " cfg=cfg,\n", ")\n", "\n", "# Evaluate on held-out data and unwrap model if Fabric wrapped it.\n", "test_metrics = evaluate_apc(model, test_loader, batch_size=cfg.batch_size)\n", "model_unwrapped = model.module if hasattr(model, \"module\") else model\n", "\n", "history_df = pd.DataFrame(history)\n", "recon_grid = build_recon_grid(model_unwrapped, vis_batch, cfg)\n", "\n", "inline_payload = {\n", " \"config\": {k: (str(v) if isinstance(v, Path) else v) for k, v in asdict(cfg).items()},\n", " \"apc_config\": asdict(model_unwrapped.config),\n", " \"test_metrics\": test_metrics,\n", "}\n", "\n", "print(\"[APC] Test metrics:\", test_metrics)" ], "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Using bfloat16 Automatic Mixed Precision (AMP)\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[APC] Device: cpu\n", "[APC] Dataset sizes: train=50000, val=2000, test=2000\n", "[APC] Fixed setup: Conv-PC + NN decoder, AdamW, MultiStepLR(66/90), batch_size=128, iters=1000\n", "[APC] 10% (100/1000) train_total=4504533.5300 val_total=2967835.8420\n", "[APC] 20% (200/1000) train_total=2212066.5950 val_total=1748269.8230\n", "[APC] 30% (300/1000) train_total=1558427.6600 val_total=1445993.7350\n", "[APC] 40% (400/1000) train_total=1366740.8050 val_total=1326261.4410\n", "[APC] 50% (500/1000) train_total=1273774.9087 val_total=1247038.0190\n", "[APC] 60% (600/1000) train_total=1229806.9688 val_total=1215286.8020\n", "[APC] 70% (700/1000) train_total=1168145.4112 val_total=1143536.5570\n", "[APC] 80% (800/1000) train_total=1128152.8813 val_total=1124754.1210\n", "[APC] 90% (900/1000) train_total=1105665.3219 val_total=1105524.4530\n", "[APC] 100% (1000/1000) train_total=1098820.8000 val_total=1104452.4800\n", "[APC] Test metrics: {'rec': 1058982.502, 'kld': 904.2743515625, 'nll': 20949.222015625, 'total': 1080835.999}\n", "\n", "[APC] Inline payload (no file output):\n", "{\n", " \"config\": {\n", " \"seed\": 0,\n", " \"data_dir\": \"data/mnist\",\n", " \"download\": true,\n", " \"dist_data\": \"binomial\",\n", " \"image_size\": 32,\n", " \"n_bits\": 8,\n", " \"device\": \"auto\",\n", " \"precision\": \"bf16-mixed\",\n", " \"num_workers\": 0,\n", " \"iters\": 1000,\n", " \"batch_size\": 128,\n", " \"val_size\": 10000,\n", " \"max_train_samples\": 60000,\n", " \"max_val_samples\": 2000,\n", " \"max_test_samples\": 2000,\n", " \"lr_encoder\": 0.1,\n", " \"lr_decoder\": 0.001,\n", " \"weight_decay\": 0.0,\n", " \"lr_gamma\": 0.1,\n", " \"latent_dim\": 64,\n", " \"conv_channels\": 64,\n", " \"conv_depth\": 3,\n", " \"conv_latent_depth\": 0,\n", " \"conv_use_sum_conv\": false,\n", " \"num_repetitions\": 1,\n", " \"nn_hidden\": 64,\n", " \"nn_res_hidden\": 16,\n", " \"nn_res_layers\": 2,\n", " \"nn_scales\": 2,\n", " \"nn_bn\": true,\n", " \"nn_out_activation\": \"identity\",\n", " \"rec_loss\": \"mse\",\n", " \"sample_tau\": 1.0,\n", " \"w_rec\": 1.0,\n", " \"w_kld\": 1.0,\n", " \"w_nll\": 1.0,\n", " \"grad_clip_norm\": null,\n", " \"warmup_pct\": 2.0,\n", " \"num_vis\": 8\n", " },\n", " \"apc_config\": {\n", " \"latent_dim\": 64,\n", " \"rec_loss\": \"mse\",\n", " \"sample_tau\": 1.0,\n", " \"train_decode_mpe\": false,\n", " \"nll_x_and_z\": true,\n", " \"loss_weights\": {\n", " \"rec\": 1.0,\n", " \"kld\": 1.0,\n", " \"nll\": 1.0\n", " }\n", " },\n", " \"test_metrics\": {\n", " \"rec\": 1058982.502,\n", " \"kld\": 904.2743515625,\n", " \"nll\": 20949.222015625,\n", " \"total\": 1080835.999\n", " }\n", "}\n" ] } ], "execution_count": 38 }, { "cell_type": "markdown", "id": "239d3b6d", "metadata": {}, "source": "## Results" }, { "cell_type": "code", "id": "ed5c8ca0", "metadata": { "execution": { "iopub.execute_input": "2026-03-04T14:51:47.441367Z", "iopub.status.busy": "2026-03-04T14:51:47.441224Z", "iopub.status.idle": "2026-03-04T14:51:47.558455Z", "shell.execute_reply": "2026-03-04T14:51:47.557971Z" }, "ExecuteTime": { "end_time": "2026-03-04T15:11:35.524078Z", "start_time": "2026-03-04T15:11:34.162333Z" } }, "source": [ "# Show the checkpoint-level training history table.\n", "display(history_df)\n", "# Plot train/validation total loss across 10% progress checkpoints.\n", "plt.figure(figsize=(9, 4))\n", "plt.plot(history_df[\"iter\"], history_df[\"train_total\"], marker=\"o\", label=\"train_total\")\n", "plt.plot(history_df[\"iter\"], history_df[\"val_total\"], marker=\"o\", label=\"val_total\")\n", "plt.xlabel(\"Iteration\")\n", "plt.ylabel(\"Loss\")\n", "plt.title(\"APC Loss Curve\")\n", "plt.grid(alpha=0.3)\n", "plt.legend()\n", "plt.show()\n", "\n", "# Display originals (top row) against reconstructions (bottom row).\n", "img = recon_grid.detach().cpu()\n", "plt.figure(figsize=(12, 3))\n", "if img.shape[0] == 1:\n", " plt.imshow(img.squeeze(0), cmap=\"gray\", vmin=0.0, vmax=1.0)\n", "else:\n", " plt.imshow(img.permute(1, 2, 0).clamp(0.0, 1.0))\n", "plt.axis(\"off\")\n", "plt.title(\"Reconstructions (top: data, bottom: reconstruction)\")\n", "plt.show()\n" ], "outputs": [ { "data": { "text/plain": [ " iter progress_pct train_total train_rec train_kld train_nll \\\n", "0 100.0 10.0 4.504534e+06 4.446900e+06 186.001488 57447.176172 \n", "1 200.0 20.0 2.212067e+06 2.185646e+06 513.629837 25907.362402 \n", "2 300.0 30.0 1.558428e+06 1.534141e+06 695.762187 23591.010215 \n", "3 400.0 40.0 1.366741e+06 1.343259e+06 773.251157 22708.204551 \n", "4 500.0 50.0 1.273775e+06 1.250953e+06 835.988552 21985.431719 \n", "5 600.0 60.0 1.229807e+06 1.207507e+06 887.402727 21412.433613 \n", "6 700.0 70.0 1.168145e+06 1.145592e+06 911.301548 21642.418105 \n", "7 800.0 80.0 1.128153e+06 1.105444e+06 921.173979 21787.575938 \n", "8 900.0 90.0 1.105665e+06 1.083133e+06 917.388387 21614.657539 \n", "9 1000.0 100.0 1.098821e+06 1.076440e+06 919.010759 21461.837246 \n", "\n", " val_total val_rec val_kld val_nll \n", "0 2967835.842 2.939025e+06 349.067951 28461.813141 \n", "1 1748269.823 1.723398e+06 635.257547 24236.865609 \n", "2 1445993.735 1.422170e+06 747.875771 23076.016375 \n", "3 1326261.441 1.303142e+06 799.736663 22320.137344 \n", "4 1247038.019 1.224443e+06 862.661738 21732.708234 \n", "5 1215286.802 1.192928e+06 896.726700 21462.263781 \n", "6 1143536.557 1.120679e+06 912.463941 21945.340609 \n", "7 1124754.121 1.102065e+06 914.862308 21774.091844 \n", "8 1105524.453 1.083016e+06 915.413808 21593.535422 \n", "9 1104452.480 1.081957e+06 917.645335 21577.747719 " ], "text/html": [ "
| \n", " | iter | \n", "progress_pct | \n", "train_total | \n", "train_rec | \n", "train_kld | \n", "train_nll | \n", "val_total | \n", "val_rec | \n", "val_kld | \n", "val_nll | \n", "
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | \n", "100.0 | \n", "10.0 | \n", "4.504534e+06 | \n", "4.446900e+06 | \n", "186.001488 | \n", "57447.176172 | \n", "2967835.842 | \n", "2.939025e+06 | \n", "349.067951 | \n", "28461.813141 | \n", "
| 1 | \n", "200.0 | \n", "20.0 | \n", "2.212067e+06 | \n", "2.185646e+06 | \n", "513.629837 | \n", "25907.362402 | \n", "1748269.823 | \n", "1.723398e+06 | \n", "635.257547 | \n", "24236.865609 | \n", "
| 2 | \n", "300.0 | \n", "30.0 | \n", "1.558428e+06 | \n", "1.534141e+06 | \n", "695.762187 | \n", "23591.010215 | \n", "1445993.735 | \n", "1.422170e+06 | \n", "747.875771 | \n", "23076.016375 | \n", "
| 3 | \n", "400.0 | \n", "40.0 | \n", "1.366741e+06 | \n", "1.343259e+06 | \n", "773.251157 | \n", "22708.204551 | \n", "1326261.441 | \n", "1.303142e+06 | \n", "799.736663 | \n", "22320.137344 | \n", "
| 4 | \n", "500.0 | \n", "50.0 | \n", "1.273775e+06 | \n", "1.250953e+06 | \n", "835.988552 | \n", "21985.431719 | \n", "1247038.019 | \n", "1.224443e+06 | \n", "862.661738 | \n", "21732.708234 | \n", "
| 5 | \n", "600.0 | \n", "60.0 | \n", "1.229807e+06 | \n", "1.207507e+06 | \n", "887.402727 | \n", "21412.433613 | \n", "1215286.802 | \n", "1.192928e+06 | \n", "896.726700 | \n", "21462.263781 | \n", "
| 6 | \n", "700.0 | \n", "70.0 | \n", "1.168145e+06 | \n", "1.145592e+06 | \n", "911.301548 | \n", "21642.418105 | \n", "1143536.557 | \n", "1.120679e+06 | \n", "912.463941 | \n", "21945.340609 | \n", "
| 7 | \n", "800.0 | \n", "80.0 | \n", "1.128153e+06 | \n", "1.105444e+06 | \n", "921.173979 | \n", "21787.575938 | \n", "1124754.121 | \n", "1.102065e+06 | \n", "914.862308 | \n", "21774.091844 | \n", "
| 8 | \n", "900.0 | \n", "90.0 | \n", "1.105665e+06 | \n", "1.083133e+06 | \n", "917.388387 | \n", "21614.657539 | \n", "1105524.453 | \n", "1.083016e+06 | \n", "915.413808 | \n", "21593.535422 | \n", "
| 9 | \n", "1000.0 | \n", "100.0 | \n", "1.098821e+06 | \n", "1.076440e+06 | \n", "919.010759 | \n", "21461.837246 | \n", "1104452.480 | \n", "1.081957e+06 | \n", "917.645335 | \n", "21577.747719 | \n", "