# Source code for spflow.zoo.apc.model

"""High-level APC model orchestration.

This module combines an APC encoder (tractable probabilistic circuit over ``X,Z``)
with an optional neural decoder and exposes the paper-style composite objective:

``total = w_rec * rec + w_kld * kld + w_nll * nll``.
"""

from __future__ import annotations

import torch
from einops import rearrange
from torch import Tensor
from torch import nn
from torch.nn import functional as F

from spflow.exceptions import InvalidParameterError, UnsupportedOperationError
from spflow.zoo.apc.debug_trace import trace_tensor
from spflow.zoo.apc.config import ApcConfig
from spflow.zoo.apc.encoders.base import ApcEncoder, LatentStats


class AutoencodingPC(nn.Module):
    """APC model combining a tractable encoder and an optional decoder.

    If ``decoder`` is ``None``, decoding is delegated to the encoder's
    evidence-conditioned ``decode`` method.
    """

    def __init__(
        self,
        encoder: ApcEncoder,
        decoder: nn.Module | None,
        config: ApcConfig,
    ) -> None:
        """Initialize an APC model.

        Args:
            encoder: APC-compatible encoder implementation.
            decoder: Optional neural decoder mapping ``z -> x``.
            config: APC model and loss configuration.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config

    def _resolve_tau(self, tau: float | None) -> float:
        """Return ``tau`` if given, otherwise fall back to ``config.sample_tau``."""
        return self.config.sample_tau if tau is None else tau

    def encode(self, x: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor:
        """Encode observed data into latent samples.

        Args:
            x: Observation tensor.
            mpe: Whether to use deterministic MPE routing.
            tau: Optional sampling temperature override.

        Returns:
            Latent samples ``z``.
        """
        trace_tensor("apc.encode.x_in", x)
        tau_eff = self._resolve_tau(tau)
        z = self.encoder.encode(x, mpe=mpe, tau=tau_eff)  # type: ignore[assignment]
        # Only trace when the encoder actually handed back a tensor.
        if isinstance(z, Tensor):
            trace_tensor("apc.encode.z_samples", z)
        return z  # type: ignore[return-value]

    def decode(self, z: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor:
        """Decode latents into reconstructions/samples in data space.

        Args:
            z: Latent samples.
            mpe: Whether to use deterministic MPE routing when using encoder decode.
            tau: Optional sampling temperature override.

        Returns:
            Reconstructed/sample ``x`` tensor. Any output scaling is caller-managed;
            this method does not apply range transforms.
        """
        trace_tensor("apc.decode.z_in", z)
        tau_eff = self._resolve_tau(tau)
        if self.decoder is None:
            x_rec = self.encoder.decode(z, mpe=mpe, tau=tau_eff)
        else:
            # The neural decoder is called with ``z`` only; ``mpe`` and ``tau`` are not forwarded.
            x_rec = self.decoder(z)
        trace_tensor("apc.decode.x_rec", x_rec)
        return x_rec

    def reconstruct(self, x: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor:
        """Reconstruct ``x`` by encoding to ``z`` and decoding back to data space."""
        z = self.encode(x, mpe=mpe, tau=tau)
        return self.decode(z, mpe=mpe, tau=tau)

    def sample_x(self, num_samples: int, *, tau: float | None = None) -> Tensor:
        """Sample synthetic observations by sampling ``z`` and decoding."""
        z = self.sample_z(num_samples=num_samples, tau=tau)
        return self.decode(z, mpe=False, tau=tau)

    def sample_z(self, num_samples: int, *, tau: float | None = None) -> Tensor:
        """Sample latents from the encoder prior."""
        tau_eff = self._resolve_tau(tau)
        return self.encoder.sample_prior_z(num_samples=num_samples, tau=tau_eff)

    @staticmethod
    def _flatten_tensor(tensor: Tensor) -> Tensor:
        """Flatten all non-batch axes into a single feature axis.

        Raises:
            InvalidParameterError: If ``tensor`` has fewer than two dimensions.
        """
        if tensor.dim() < 2:
            raise InvalidParameterError(
                f"Expected tensor with batch dimension and at least one feature axis, got shape {tuple(tensor.shape)}."
            )
        return rearrange(tensor, "b ... -> b (...)")

    def _reconstruction_loss(self, x: Tensor, x_rec: Tensor) -> Tensor:
        """Compute reconstruction loss with feature-sum / batch-mean reduction.

        Input preprocessing is intentionally caller-controlled. This method compares
        ``x`` and ``x_rec`` directly.

        Raises:
            InvalidParameterError: If the flattened shapes differ or
                ``config.rec_loss`` is neither ``"mse"`` nor ``"bce"``.
        """
        x_flat = self._flatten_tensor(x)
        x_rec_flat = self._flatten_tensor(x_rec)
        if x_flat.shape != x_rec_flat.shape:
            raise InvalidParameterError(
                f"Reconstruction shape mismatch: x has {tuple(x_flat.shape)}, x_rec has {tuple(x_rec_flat.shape)}."
            )

        batch_size = x_flat.shape[0]
        if self.config.rec_loss == "mse":
            return F.mse_loss(x_rec_flat, x_flat, reduction="sum") / batch_size
        if self.config.rec_loss == "bce":
            return F.binary_cross_entropy(x_rec_flat, x_flat, reduction="sum") / batch_size
        raise InvalidParameterError(f"Unsupported rec_loss '{self.config.rec_loss}'.")

    @staticmethod
    def _kld_from_stats(stats: LatentStats) -> Tensor:
        """Compute mean KL divergence from exact encoder-provided per-sample KL.

        Raises:
            InvalidParameterError: If ``kld_per_sample`` is not a 1-D tensor or its
                length differs from the leading dimension of ``mu``.
        """
        if not isinstance(stats.kld_per_sample, Tensor):
            raise InvalidParameterError("LatentStats.kld_per_sample must be a Tensor.")
        if stats.kld_per_sample.dim() != 1:
            raise InvalidParameterError(
                "LatentStats.kld_per_sample must have shape (batch,), "
                f"got {tuple(stats.kld_per_sample.shape)}."
            )
        if stats.mu.shape[0] != stats.kld_per_sample.shape[0]:
            raise InvalidParameterError(
                "Latent stats batch mismatch: "
                f"mu batch {stats.mu.shape[0]} vs kld_per_sample batch {stats.kld_per_sample.shape[0]}."
            )
        return stats.kld_per_sample.mean()

    @staticmethod
    def _float_scalar_zero(x: Tensor) -> Tensor:
        """Create a floating scalar zero on ``x``'s device.

        Uses ``x``'s dtype when it is floating point, otherwise the torch default dtype.
        """
        dtype = x.dtype if x.is_floating_point() else torch.get_default_dtype()
        return torch.zeros((), device=x.device, dtype=dtype)

    def _encode_for_loss(
        self, x: Tensor, *, tau: float, need_stats: bool
    ) -> tuple[LatentStats | None, Tensor]:
        """Run the encoder for the loss and validate the structure of its output.

        Returns:
            ``(stats, z_samples)`` where ``stats`` is ``None`` when ``need_stats`` is falsy.

        Raises:
            UnsupportedOperationError: If the encoder output does not have the
                structure expected for the requested ``return_latent_stats`` mode.
        """
        encoded = self.encoder.encode(
            x,
            mpe=False,
            tau=tau,
            return_latent_stats=need_stats,
        )
        if not need_stats:
            if not isinstance(encoded, Tensor):
                raise UnsupportedOperationError(
                    "Encoder must return latent samples as a Tensor when return_latent_stats=False."
                )
            return None, encoded

        if (
            not isinstance(encoded, tuple)
            or len(encoded) != 2
            or not isinstance(encoded[0], LatentStats)
            or not isinstance(encoded[1], Tensor)
        ):
            raise UnsupportedOperationError(
                "Encoder must return (LatentStats, z_samples) when return_latent_stats=True."
            )
        return encoded[0], encoded[1]

    def _latent_for_decoding(self, stats: LatentStats | None, z_samples: Tensor) -> Tensor:
        """Pick the latent fed to the decoder for the reconstruction term.

        Returns ``stats.decode_latent`` when ``config.train_decode_mpe`` is set,
        otherwise the sampled ``z_samples``.

        Raises:
            UnsupportedOperationError: If ``train_decode_mpe`` is set but latent
                stats or a tensor ``decode_latent`` are unavailable.
        """
        if not self.config.train_decode_mpe:
            return z_samples
        if stats is None:
            raise UnsupportedOperationError(
                "train_decode_mpe=True requires encoder latent stats; "
                "disable train_decode_mpe or use an encoder that supports latent stats."
            )
        if not isinstance(stats.decode_latent, Tensor):
            raise UnsupportedOperationError(
                "train_decode_mpe=True requires LatentStats.decode_latent to be a Tensor."
            )
        return stats.decode_latent

    def loss_components(self, x: Tensor) -> dict[str, Tensor]:
        """Compute APC loss components and intermediate tensors.

        Terms whose weight is not positive are reported as a scalar zero and their
        computation is skipped. The encoder is only invoked when at least one of the
        terms (or ``train_decode_mpe``) requires it.

        Args:
            x: Observation tensor.

        Returns:
            Dictionary with scalar terms ``rec``, ``kld``, ``nll``, ``total`` and,
            when they were computed, the intermediates ``z``, ``x_rec``, ``mu``,
            ``logvar``.

        Raises:
            UnsupportedOperationError: If the encoder cannot supply the outputs
                required by the enabled loss terms.
            RuntimeError: If a term needs latent samples that were not computed.
        """
        tau_eff = self.config.sample_tau
        weights = self.config.loss_weights
        need_stats = (weights.kld > 0.0) or self.config.train_decode_mpe
        need_z_samples = (weights.rec > 0.0) or (weights.nll > 0.0)

        stats: LatentStats | None = None
        z_samples: Tensor | None = None
        components: dict[str, Tensor] = {}

        if need_stats or need_z_samples:
            stats, z_samples = self._encode_for_loss(x, tau=tau_eff, need_stats=need_stats)
            if stats is not None:
                components["mu"] = stats.mu
                components["logvar"] = stats.logvar
            components["z"] = z_samples

        # Terms are evaluated in the order rec -> nll -> kld.
        rec = self._float_scalar_zero(x)
        if weights.rec > 0.0:
            if z_samples is None:
                raise RuntimeError("Reconstruction loss requested but latent samples were not computed.")
            z_to_decode = self._latent_for_decoding(stats, z_samples)
            x_rec = self.decode(z_to_decode, mpe=False, tau=tau_eff)
            rec = self._reconstruction_loss(x, x_rec)
            components["x_rec"] = x_rec

        nll = self._float_scalar_zero(x)
        if weights.nll > 0.0:
            if self.config.nll_x_and_z:
                if z_samples is None:
                    raise RuntimeError("Joint NLL requested but latent samples were not computed.")
                lls = self.joint_log_likelihood(x, z_samples)
            else:
                lls = self.log_likelihood_x(x)
            # Sum of log-likelihoods divided by the batch size.
            nll = -lls.sum() / x.shape[0]

        kld = self._float_scalar_zero(x)
        if weights.kld > 0.0:
            if stats is None:
                raise UnsupportedOperationError(
                    "KL loss requires encoder latent stats, but they are unavailable. "
                    "Set loss_weights.kld=0 or use an encoder that supports return_latent_stats=True."
                )
            kld = self._kld_from_stats(stats)

        total = weights.rec * rec + weights.kld * kld + weights.nll * nll
        components["rec"] = rec
        components["kld"] = kld
        components["nll"] = nll
        components["total"] = total
        return components

    def loss(self, x: Tensor) -> Tensor:
        """Return only the weighted total APC loss."""
        return self.loss_components(x)["total"]

    def log_likelihood_x(self, x: Tensor) -> Tensor:
        """Compute encoder marginal log-likelihood ``log p(x)`` per sample."""
        return self.encoder.log_likelihood_x(x)

    def joint_log_likelihood(self, x: Tensor, z: Tensor) -> Tensor:
        """Compute encoder joint log-likelihood ``log p(x, z)`` per sample."""
        return self.encoder.joint_log_likelihood(x, z)

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        """Alias for :meth:`loss_components` to integrate with training loops."""
        return self.loss_components(x)

    def extra_repr(self) -> str:
        """Summarize latent dimension, reconstruction loss type and loss weights."""
        weights = self.config.loss_weights
        return (
            f"latent_dim={self.config.latent_dim}, rec_loss={self.config.rec_loss}, "
            f"weights=(rec={weights.rec}, kld={weights.kld}, "
            f"nll={weights.nll})"
        )