"""High-level APC model orchestration.
This module combines an APC encoder (tractable probabilistic circuit over ``X,Z``)
with an optional neural decoder and exposes the paper-style composite objective:
``total = w_rec * rec + w_kld * kld + w_nll * nll``.
"""
from __future__ import annotations
import torch
from einops import rearrange
from torch import Tensor
from torch import nn
from torch.nn import functional as F
from spflow.exceptions import InvalidParameterError, UnsupportedOperationError
from spflow.zoo.apc.debug_trace import trace_tensor
from spflow.zoo.apc.config import ApcConfig
from spflow.zoo.apc.encoders.base import ApcEncoder, LatentStats
class AutoencodingPC(nn.Module):
    """APC model combining a tractable encoder and an optional decoder.

    Exposes the paper-style composite objective pieces
    (``total = w_rec * rec + w_kld * kld + w_nll * nll``) alongside
    encode/decode/sample helpers.

    If ``decoder`` is ``None``, decoding is delegated to the encoder's
    evidence-conditioned ``decode`` method.
    """

    def __init__(
        self,
        encoder: ApcEncoder,
        decoder: nn.Module | None,
        config: ApcConfig,
    ) -> None:
        """Initialize an APC model.

        Args:
            encoder: APC-compatible encoder implementation.
            decoder: Optional neural decoder mapping ``z -> x``.
            config: APC model and loss configuration.
        """
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.config = config

    def encode(self, x: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor:
        """Encode observed data into latent samples.

        Args:
            x: Observation tensor.
            mpe: Whether to use deterministic MPE routing.
            tau: Optional sampling temperature override; ``None`` falls back
                to ``config.sample_tau``.

        Returns:
            Latent samples ``z``.
        """
        trace_tensor("apc.encode.x_in", x)
        tau_eff = self.config.sample_tau if tau is None else tau
        z = self.encoder.encode(x, mpe=mpe, tau=tau_eff)  # type: ignore[assignment]
        # Encoders may return non-Tensor containers; only trace plain tensors.
        if isinstance(z, Tensor):
            trace_tensor("apc.encode.z_samples", z)
        return z  # type: ignore[return-value]

    def decode(self, z: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor:
        """Decode latents into reconstructions/samples in data space.

        Args:
            z: Latent samples.
            mpe: Whether to use deterministic MPE routing when using encoder
                decode (ignored by a neural decoder).
            tau: Optional sampling temperature override; ``None`` falls back
                to ``config.sample_tau``.

        Returns:
            Reconstructed/sample ``x`` tensor.
        """
        trace_tensor("apc.decode.z_in", z)
        tau_eff = self.config.sample_tau if tau is None else tau
        if self.decoder is None:
            # No neural decoder: delegate to the encoder's
            # evidence-conditioned decode.
            x_rec = self.encoder.decode(z, mpe=mpe, tau=tau_eff)
            trace_tensor("apc.decode.x_rec", x_rec)
            return x_rec
        x_rec = self.decoder(z)
        if x_rec.dim() > 2:
            # Image-like output: map the decoder's [-1, 1] range back to
            # discrete pixel levels [0, 2**n_bits - 1]. Note this is NOT the
            # inverse of _reconstruction_loss's legacy scaling (see below).
            x_rec = (x_rec + 1.0) / 2.0 * (2**self.config.n_bits - 1)
        trace_tensor("apc.decode.x_rec", x_rec)
        return x_rec

    def reconstruct(self, x: Tensor, *, mpe: bool = False, tau: float | None = None) -> Tensor:
        """Reconstruct ``x`` by encoding to ``z`` and decoding back to data space."""
        z = self.encode(x, mpe=mpe, tau=tau)
        return self.decode(z, mpe=mpe, tau=tau)

    def sample_x(self, num_samples: int, *, tau: float | None = None) -> Tensor:
        """Sample synthetic observations by sampling ``z`` and decoding."""
        z = self.sample_z(num_samples=num_samples, tau=tau)
        return self.decode(z, mpe=False, tau=tau)

    def sample_z(self, num_samples: int, *, tau: float | None = None) -> Tensor:
        """Sample latents from the encoder prior."""
        tau_eff = self.config.sample_tau if tau is None else tau
        return self.encoder.sample_prior_z(num_samples=num_samples, tau=tau_eff)

    @staticmethod
    def _flatten_tensor(tensor: Tensor) -> Tensor:
        """Flatten all non-batch axes into a single feature axis.

        Raises:
            InvalidParameterError: If ``tensor`` has no feature axis
                (rank < 2).
        """
        if tensor.dim() < 2:
            raise InvalidParameterError(
                f"Expected tensor with batch dimension and at least one feature axis, got shape {tuple(tensor.shape)}."
            )
        # torch.flatten is equivalent to einops' "b ... -> b (...)" pattern
        # here and avoids the extra dependency for a trivial reshape.
        return torch.flatten(tensor, start_dim=1)

    def _reconstruction_loss(self, x: Tensor, x_rec: Tensor) -> Tensor:
        """Compute reconstruction loss following the reference APC reduction.

        For image-like tensors (rank > 2), this applies the same legacy
        scaling used in the reference implementation before the
        reconstruction criterion.

        Raises:
            InvalidParameterError: On flattened shape mismatch or an
                unsupported ``config.rec_loss``.
        """
        if x.dim() > 2:
            # Preserve the reference APC's historical scaling behavior exactly.
            # NOTE: This uses n_bits**2 - 1 (not 2**n_bits - 1), so it is
            # deliberately NOT the inverse of decode()'s pixel rescaling.
            denom = float(self.config.n_bits**2 - 1)
            x = x / denom * 2.0 - 1.0
            x_rec = x_rec / denom * 2.0 - 1.0
        x_flat = self._flatten_tensor(x)
        x_rec_flat = self._flatten_tensor(x_rec)
        if x_flat.shape != x_rec_flat.shape:
            raise InvalidParameterError(
                f"Reconstruction shape mismatch: x has {tuple(x_flat.shape)}, x_rec has {tuple(x_rec_flat.shape)}."
            )
        batch_size = x_flat.shape[0]
        # Sum over all elements, then average per sample (reference reduction).
        if self.config.rec_loss == "mse":
            return F.mse_loss(x_rec_flat, x_flat, reduction="sum") / batch_size
        if self.config.rec_loss == "bce":
            # NOTE(review): for image-like inputs the scaling above maps
            # values into [-1, 1], which binary_cross_entropy rejects —
            # presumably "bce" is only used with flat inputs already in
            # [0, 1]; confirm against callers.
            return F.binary_cross_entropy(x_rec_flat, x_flat, reduction="sum") / batch_size
        raise InvalidParameterError(f"Unsupported rec_loss '{self.config.rec_loss}'.")

    @staticmethod
    def _kld_from_stats(stats: LatentStats) -> Tensor:
        """Compute mean KL divergence to standard Normal from moment stats.

        Raises:
            InvalidParameterError: If ``mu``/``logvar`` shapes disagree or
                the stats carry no latent dimension beyond the batch axis.
        """
        if stats.mu.shape != stats.logvar.shape:
            raise InvalidParameterError(
                f"Latent stats shape mismatch: mu {tuple(stats.mu.shape)} vs logvar {tuple(stats.logvar.shape)}."
            )
        reduce_dims = tuple(range(1, stats.mu.dim()))
        if len(reduce_dims) == 0:
            raise InvalidParameterError("Latent stats must include at least one latent dimension.")
        # Closed-form KL(N(mu, diag(sigma^2)) || N(0, I)), summed over latent
        # dims, then averaged over the batch.
        kld_per_sample = 0.5 * (stats.mu.pow(2) + stats.logvar.exp() - 1.0 - stats.logvar).sum(
            dim=reduce_dims
        )
        return kld_per_sample.mean()

    def loss_components(self, x: Tensor) -> dict[str, Tensor]:
        """Compute APC loss components and intermediate tensors.

        Args:
            x: Observation tensor.

        Returns:
            Dictionary with scalar terms ``rec``, ``kld``, ``nll``, ``total``
            and helpful intermediates ``z``, ``x_rec``, ``mu``, ``logvar``.

        Raises:
            UnsupportedOperationError: Always — KL-style training is
                currently disabled (see message).
        """
        del x
        raise UnsupportedOperationError(
            "APC KL-style training is unavailable after sample rollback. "
            "loss_components() is currently unsupported."
        )

    def loss(self, x: Tensor) -> Tensor:
        """Return only the weighted total APC loss.

        Raises:
            UnsupportedOperationError: Always — KL-style training is
                currently disabled (see message).
        """
        del x
        raise UnsupportedOperationError(
            "APC KL-style training is unavailable after sample rollback. loss() is currently unsupported."
        )

    def log_likelihood_x(self, x: Tensor) -> Tensor:
        """Compute encoder marginal log-likelihood ``log p(x)`` per sample."""
        return self.encoder.log_likelihood_x(x)

    def joint_log_likelihood(self, x: Tensor, z: Tensor) -> Tensor:
        """Compute encoder joint log-likelihood ``log p(x, z)`` per sample."""
        return self.encoder.joint_log_likelihood(x, z)

    def forward(self, x: Tensor) -> dict[str, Tensor]:
        """Alias for :meth:`loss_components` to integrate with training loops.

        Currently raises like :meth:`loss_components`.
        """
        return self.loss_components(x)

    def extra_repr(self) -> str:
        """Summarize key config fields for ``repr``/``print`` of the module."""
        return (
            f"latent_dim={self.config.latent_dim}, rec_loss={self.config.rec_loss}, "
            f"weights=(rec={self.config.loss_weights.rec}, kld={self.config.loss_weights.kld}, "
            f"nll={self.config.loss_weights.nll})"
        )