# Source code for spflow.zoo.apc.model

"""High-level APC model orchestration.

This module combines an APC encoder (tractable probabilistic circuit over ``X,Z``)
with an optional neural decoder and exposes the paper-style composite objective:

``total = w_rec * rec + w_kld * kld + w_nll * nll``.
"""

from __future__ import annotations

import torch
from einops import rearrange
from torch import Tensor
from torch import nn
from torch.nn import functional as F

from spflow.exceptions import InvalidParameterError, UnsupportedOperationError
from spflow.zoo.apc.debug_trace import trace_tensor
from spflow.zoo.apc.config import ApcConfig
from spflow.zoo.apc.encoders.base import ApcEncoder, LatentStats


class AutoencodingPC(nn.Module):
    """APC model combining a tractable encoder and an optional decoder.

    If ``decoder`` is ``None``, decoding is delegated to the encoder's
    evidence-conditioned ``decode`` method.
    """
[docs] def __init__( self, encoder: ApcEncoder, decoder: nn.Module | None, config: ApcConfig, ) -> None: """Initialize an APC model. Args: encoder: APC-compatible encoder implementation. decoder: Optional neural decoder mapping ``z -> x``. config: APC model and loss configuration. """ super().__init__() self.encoder = encoder self.decoder = decoder self.config = config
[docs] def encode(self, x: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor: """Encode observed data into latent samples. Args: x: Observation tensor. mpe: Whether to use deterministic MPE routing. tau: Optional sampling temperature override. Returns: Latent samples ``z``. """ trace_tensor("apc.encode.x_in", x) tau_eff = self.config.sample_tau if tau is None else tau z = self.encoder.encode(x, mpe=mpe, tau=tau_eff) # type: ignore[assignment] if isinstance(z, Tensor): trace_tensor("apc.encode.z_samples", z) return z # type: ignore[return-value]
[docs] def decode(self, z: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor: """Decode latents into reconstructions/samples in data space. Args: z: Latent samples. mpe: Whether to use deterministic MPE routing when using encoder decode. tau: Optional sampling temperature override. Returns: Reconstructed/sample ``x`` tensor. """ trace_tensor("apc.decode.z_in", z) tau_eff = self.config.sample_tau if tau is None else tau if self.decoder is None: x_rec = self.encoder.decode(z, mpe=mpe, tau=tau_eff) trace_tensor("apc.decode.x_rec", x_rec) return x_rec x_rec = self.decoder(z) if x_rec.dim() > 2: x_rec = (x_rec + 1.0) / 2.0 * (2**self.config.n_bits - 1) trace_tensor("apc.decode.x_rec", x_rec) return x_rec
[docs] def reconstruct(self, x: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor: """Reconstruct ``x`` by encoding to ``z`` and decoding back to data space.""" z = self.encode(x, mpe=mpe, tau=tau) return self.decode(z, mpe=mpe, tau=tau)
[docs] def sample_x(self, num_samples: int, *, tau: float | None = None) -> Tensor: """Sample synthetic observations by sampling ``z`` and decoding.""" z = self.sample_z(num_samples=num_samples, tau=tau) return self.decode(z, mpe=False, tau=tau)
[docs] def sample_z(self, num_samples: int, *, tau: float | None = None) -> Tensor: """Sample latents from the encoder prior.""" tau_eff = self.config.sample_tau if tau is None else tau return self.encoder.sample_prior_z(num_samples=num_samples, tau=tau_eff)
@staticmethod def _flatten_tensor(tensor: Tensor) -> Tensor: """Flatten all non-batch axes into a single feature axis.""" if tensor.dim() < 2: raise InvalidParameterError( f"Expected tensor with batch dimension and at least one feature axis, got shape {tuple(tensor.shape)}." ) return rearrange(tensor, "b ... -> b (...)") def _reconstruction_loss(self, x: Tensor, x_rec: Tensor) -> Tensor: """Compute reconstruction loss following the reference APC reduction. For image-like tensors (rank > 2), this applies the same legacy scaling used in the reference implementation before the reconstruction criterion. """ if x.dim() > 2: # Preserve the reference APC's historical scaling behavior exactly. # NOTE: This uses n_bits**2 - 1 (not 2**n_bits - 1). denom = float(self.config.n_bits**2 - 1) x = x / denom * 2.0 - 1.0 x_rec = x_rec / denom * 2.0 - 1.0 x_flat = self._flatten_tensor(x) x_rec_flat = self._flatten_tensor(x_rec) if x_flat.shape != x_rec_flat.shape: raise InvalidParameterError( f"Reconstruction shape mismatch: x has {tuple(x_flat.shape)}, x_rec has {tuple(x_rec_flat.shape)}." ) batch_size = x_flat.shape[0] if self.config.rec_loss == "mse": return F.mse_loss(x_rec_flat, x_flat, reduction="sum") / batch_size if self.config.rec_loss == "bce": return F.binary_cross_entropy(x_rec_flat, x_flat, reduction="sum") / batch_size raise InvalidParameterError(f"Unsupported rec_loss '{self.config.rec_loss}'.") @staticmethod def _kld_from_stats(stats: LatentStats) -> Tensor: """Compute mean KL divergence to standard Normal from moment stats.""" if stats.mu.shape != stats.logvar.shape: raise InvalidParameterError( f"Latent stats shape mismatch: mu {tuple(stats.mu.shape)} vs logvar {tuple(stats.logvar.shape)}." 
) reduce_dims = tuple(range(1, stats.mu.dim())) if len(reduce_dims) == 0: raise InvalidParameterError("Latent stats must include at least one latent dimension.") kld_per_sample = 0.5 * (stats.mu.pow(2) + stats.logvar.exp() - 1.0 - stats.logvar).sum( dim=reduce_dims ) return kld_per_sample.mean()
[docs] def loss_components(self, x: Tensor) -> dict[str, Tensor]: """Compute APC loss components and intermediate tensors. Args: x: Observation tensor. Returns: Dictionary with scalar terms ``rec``, ``kld``, ``nll``, ``total`` and helpful intermediates ``z``, ``x_rec``, ``mu``, ``logvar``. """ del x raise UnsupportedOperationError( "APC KL-style training is unavailable after sample rollback. " "loss_components() is currently unsupported." )
[docs] def loss(self, x: Tensor) -> Tensor: """Return only the weighted total APC loss.""" del x raise UnsupportedOperationError( "APC KL-style training is unavailable after sample rollback. loss() is currently unsupported." )
[docs] def log_likelihood_x(self, x: Tensor) -> Tensor: """Compute encoder marginal log-likelihood ``log p(x)`` per sample.""" return self.encoder.log_likelihood_x(x)
[docs] def joint_log_likelihood(self, x: Tensor, z: Tensor) -> Tensor: """Compute encoder joint log-likelihood ``log p(x, z)`` per sample.""" return self.encoder.joint_log_likelihood(x, z)
[docs] def forward(self, x: Tensor) -> dict[str, Tensor]: """Alias for :meth:`loss_components` to integrate with training loops.""" return self.loss_components(x)
def extra_repr(self) -> str: return ( f"latent_dim={self.config.latent_dim}, rec_loss={self.config.rec_loss}, " f"weights=(rec={self.config.loss_weights.rec}, kld={self.config.loss_weights.kld}, " f"nll={self.config.loss_weights.nll})" )