Source code for spflow.zoo.apc.config

"""Typed configuration objects for Autoencoding Probabilistic Circuits (APC)."""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

from spflow.exceptions import InvalidParameterError


[docs] @dataclass(frozen=True) class ApcLossWeights: """Weights for the APC training objective terms. Attributes: rec: Weight for the reconstruction loss. kld: Weight for the latent KL term. nll: Weight for the joint negative log-likelihood term. """ rec: float = 1.0 kld: float = 1.0 nll: float = 1.0 def __post_init__(self) -> None: for name, value in (("rec", self.rec), ("kld", self.kld), ("nll", self.nll)): if value < 0.0: raise InvalidParameterError(f"ApcLossWeights.{name} must be >= 0, got {value}.")
[docs] @dataclass(frozen=True) class ApcConfig: """Core APC model configuration. Attributes: latent_dim: Dimensionality of the latent variable block ``Z``. rec_loss: Reconstruction criterion used by :class:`AutoencodingPC`. n_bits: Bit-depth used by reference-style image reconstruction scaling. sample_tau: Temperature for differentiable sampling (SIMPLE/Gumbel style paths). loss_weights: Weights for ``rec``, ``kld``, and ``nll`` objective terms. """ latent_dim: int rec_loss: Literal["mse", "bce"] = "mse" n_bits: int = 8 sample_tau: float = 1.0 loss_weights: ApcLossWeights = field(default_factory=ApcLossWeights) def __post_init__(self) -> None: if self.latent_dim <= 0: raise InvalidParameterError(f"latent_dim must be >= 1, got {self.latent_dim}.") if self.n_bits <= 1: raise InvalidParameterError(f"n_bits must be >= 2, got {self.n_bits}.") if self.sample_tau <= 0.0: raise InvalidParameterError(f"sample_tau must be > 0, got {self.sample_tau}.")
[docs] @dataclass(frozen=True) class ApcTrainConfig: """Configuration for lightweight APC trainer helpers. Attributes: epochs: Number of training epochs. batch_size: Batch size used for tensor-backed training/evaluation inputs. learning_rate: Optimizer learning rate when an optimizer is not provided. weight_decay: Adam weight decay when an optimizer is not provided. grad_clip_norm: Optional gradient clipping threshold (L2 norm). """ epochs: int = 1 batch_size: int = 64 learning_rate: float = 1e-3 weight_decay: float = 0.0 grad_clip_norm: float | None = None def __post_init__(self) -> None: if self.epochs <= 0: raise InvalidParameterError(f"epochs must be >= 1, got {self.epochs}.") if self.batch_size <= 0: raise InvalidParameterError(f"batch_size must be >= 1, got {self.batch_size}.") if self.learning_rate <= 0.0: raise InvalidParameterError(f"learning_rate must be > 0, got {self.learning_rate}.") if self.weight_decay < 0.0: raise InvalidParameterError(f"weight_decay must be >= 0, got {self.weight_decay}.") if self.grad_clip_norm is not None and self.grad_clip_norm <= 0.0: raise InvalidParameterError(f"grad_clip_norm must be > 0 when set, got {self.grad_clip_norm}.")