Continuous Mixtures (CMs)

Continuous Mixtures of Tractable Probabilistic Models introduce a low-dimensional latent variable z and a decoder network φ that maps z to the parameters of a tractable model with a fixed structure, so that the model's parameters vary continuously with z.

Reference

Continuous Mixtures are described in:

Overview

The marginal density of a continuous mixture is an integral over the latent space:

\[p(x) = \mathbb{E}_{p(z)}[p(x \mid \phi(z))] = \int p(x \mid \phi(z)) p(z)\,dz\]

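SPFlow approximates this integral with N randomized quasi-Monte Carlo (RQMC) integration points z_1, …, z_N taken from a randomized Sobol sequence (N is num_points_train during training and num_points_eval for evaluation):

\[p(x) \approx \frac{1}{N}\sum_{i=1}^{N} p(x \mid \phi(z_i))\]

The resulting finite mixture is then compiled into a standard SPFlow module for inference.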
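The RQMC approximation can be sketched in NumPy/SciPy. This is not SPFlow's implementation: the standard normal prior p(z), the toy linear decoder (W, b) and the helper names rqmc_latent_points and bernoulli_mixture_loglik are assumptions made for illustration. The sketch shows how N randomized Sobol points turn the integral into an equally weighted finite mixture of factorized Bernoulli components, i.e. the kind of discrete mixture the learners compile to.

```python
import numpy as np
from scipy.special import expit, logsumexp
from scipy.stats import norm, qmc


def rqmc_latent_points(num_points: int, latent_dim: int, seed: int = 0) -> np.ndarray:
    """Scrambled Sobol points mapped from [0, 1)^d to a standard normal prior (assumed)."""
    sampler = qmc.Sobol(d=latent_dim, scramble=True, seed=seed)
    u = sampler.random(num_points)            # num_points should be a power of 2
    u = np.clip(u, 1e-12, 1.0 - 1e-12)        # keep the inverse CDF finite
    return norm.ppf(u)                        # shape (num_points, latent_dim)


def bernoulli_mixture_loglik(x: np.ndarray, z: np.ndarray, W: np.ndarray, b: np.ndarray) -> np.ndarray:
    """RQMC estimate log p(x) ~= logsumexp_i log p(x | phi(z_i)) - log N.

    A toy linear decoder phi(z) = sigmoid(z @ W + b) (hypothetical) gives the
    Bernoulli means of a fully factorized component for each point z_i.
    """
    probs = np.clip(expit(z @ W + b), 1e-6, 1.0 - 1e-6)            # (N, F)
    # Sum of per-feature Bernoulli log-probabilities: (num_rows, N)
    log_comp = x @ np.log(probs).T + (1.0 - x) @ np.log1p(-probs).T
    # Uniform weights 1/N turn the N components into one discrete mixture.
    return logsumexp(log_comp, axis=1) - np.log(z.shape[0])


z = rqmc_latent_points(num_points=128, latent_dim=4, seed=0)
rng = np.random.default_rng(0)
W, b = rng.normal(size=(4, 3)), rng.normal(size=3)
x = np.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0]])
print(bernoulli_mixture_loglik(x, z, W, b))   # one log-likelihood per row
```

Because every component is a normalized distribution and the weights sum to one, the estimate is itself a valid distribution for any number of points, which is what makes the compiled module usable for exact inference.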

Key features:

  • Latent Optimization (LO): Optionally optimizes the latent variables for a better fit to the data; enabled by passing a LatentOptimizationConfig as the lo argument.

  • Discrete compilation: Compiled circuits can be used with all standard SPFlow operations.

  • Multiple structures: Supports both factorized (independent) and Chow-Liu tree structures for the components.

Implementation

Factorized Continuous Mixtures

spflow.zoo.cms.learn_continuous_mixture_factorized(data, *, leaf, latent_dim=4, num_points_train=128, num_points_eval=None, num_epochs=300, batch_size=128, lr=0.001, seed=0, device=None, dtype=None, num_cats=None, normal_eps=0.0001, val_data=None, patience=15, lo=None)

Learn a continuous mixture with fully factorized structure S_F.

Parameters:
  • data (Tensor) – Training data of shape (N,F). NaNs are supported and treated as missing values (marginalized out).

  • leaf (Literal['bernoulli', 'categorical', 'normal']) – Leaf distribution family.

  • latent_dim (int) – Latent dimension d.

  • num_points_train (int) – Number of RQMC integration points during training.

  • num_points_eval (int | None) – Number of integration points for evaluation/early stopping. Defaults to num_points_train if None.

  • num_epochs (int) – Number of training epochs.

  • batch_size (int) – Mini-batch size.

  • lr (float) – Learning rate for Adam.

  • seed (int) – Random seed.

  • device (device | None) – Optional device for training.

  • dtype (dtype | None) – Optional dtype for training computations.

  • num_cats (int | None) – Number of categories K for categorical leaves.

  • normal_eps (float) – Minimum scale for Normal leaves.

  • val_data (Tensor | None) – Optional validation data for early stopping.

  • patience (int) – Early stopping patience in epochs.

  • lo (LatentOptimizationConfig | None) – Latent optimization configuration. If None, LO is disabled.
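The val_data and patience parameters suggest a patience rule on the validation log-likelihood. The sketch below assumes that rule: training stops once the validation score has not improved for patience consecutive epochs, and the best epoch is kept. The function name early_stopping_epoch is hypothetical, and SPFlow's exact stopping criterion may differ.

```python
from typing import Sequence


def early_stopping_epoch(val_scores: Sequence[float], patience: int) -> tuple[int, int]:
    """Return (best_epoch, stop_epoch) under a patience rule on validation log-likelihood.

    Higher scores are better. Training stops once `patience` consecutive epochs pass
    without a strict improvement over the best score; the best epoch's weights are kept.
    """
    best_epoch, best_score = 0, float("-inf")
    for epoch, score in enumerate(val_scores):
        if score > best_score:
            best_epoch, best_score = epoch, score
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    # Ran out of epochs before patience was exhausted.
    return best_epoch, len(val_scores) - 1


# Mean validation log-likelihood per epoch (made-up numbers for illustration).
scores = [-5.0, -4.0, -3.0, -3.5, -3.2, -3.1, -3.0]
print(early_stopping_epoch(scores, patience=3))  # (2, 5)
```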

Return type:

Sum

Returns:

A compiled SPFlow module (discrete mixture / Sum) representing the trained model.
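Because S_F factorizes over features, marginalizing a missing value only drops that feature's factor from each component, since ∫ p(x_f | φ(z_i)) dx_f = 1. The sketch below computes the resulting mixture log-likelihood for Normal leaves. The per-point means mu and unconstrained scales raw_scale stand in for the decoder output, the function name is hypothetical, and the softplus-plus-normal_eps parameterization of the scale is an assumption about how the minimum scale is enforced.

```python
import numpy as np
from scipy.special import logsumexp


def factorized_normal_mixture_loglik(x, mu, raw_scale, normal_eps=1e-4):
    """log p(x) for a uniform mixture of fully factorized Normal components.

    x:         (num_rows, F) data; NaN marks a missing value.
    mu:        (N, F) means, one row per integration point z_i.
    raw_scale: (N, F) unconstrained scales (assumed parameterization).
    """
    scale = np.logaddexp(0.0, raw_scale) + normal_eps    # softplus, floored at normal_eps
    observed = ~np.isnan(x)                              # (num_rows, F)
    x_filled = np.where(observed, x, 0.0)                # placeholder; masked out below
    standardized = (x_filled[:, None, :] - mu[None, :, :]) / scale[None, :, :]
    log_pdf = -0.5 * standardized**2 - np.log(scale)[None, :, :] - 0.5 * np.log(2.0 * np.pi)
    # A missing feature contributes log(1) = 0 to every component.
    log_pdf = np.where(observed[:, None, :], log_pdf, 0.0)
    log_comp = log_pdf.sum(axis=2)                       # (num_rows, N)
    return logsumexp(log_comp, axis=1) - np.log(mu.shape[0])


rng = np.random.default_rng(0)
mu, raw_scale = rng.normal(size=(16, 2)), rng.normal(size=(16, 2))
x = np.array([[0.3, -1.0], [0.3, np.nan], [np.nan, np.nan]])
print(factorized_normal_mixture_loglik(x, mu, raw_scale))
```

A row whose features are all missing gets log-likelihood 0, and a row with one missing feature scores exactly like the same data evaluated on the remaining features alone.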

Chow–Liu Continuous Mixtures

spflow.zoo.cms.learn_continuous_mixture_cltree(data, *, leaf, latent_dim=4, num_points_train=128, num_points_eval=None, num_epochs=300, batch_size=128, lr=0.001, seed=0, device=None, dtype=None, num_cats=None, val_data=None, patience=15, lo=None, alpha=0.01)

Learn a continuous mixture with Chow–Liu tree structure S_CLT (discrete only).

Notes

  • This learner supports only discrete leaves (Bernoulli / Categorical).

  • Data must be complete (no NaNs) and integer-coded.
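A quick pre-check for these requirements might look like the following. validate_cltree_data is a hypothetical helper, not part of SPFlow. It works on NumPy arrays, so a torch Tensor would first be converted with .cpu().numpy(). Dropping incomplete rows is shown as one simple option; imputation is another.

```python
import numpy as np


def validate_cltree_data(data, num_cats: int) -> np.ndarray:
    """Hypothetical pre-check for the documented Chow-Liu learner requirements."""
    arr = np.asarray(data, dtype=float)
    if arr.ndim != 2:
        raise ValueError(f"expected shape (N, F), got {arr.shape}")
    if np.isnan(arr).any():
        raise ValueError("data must be complete (no NaNs)")
    if not np.array_equal(arr, np.round(arr)):
        raise ValueError("data must be integer-coded")
    if arr.min() < 0 or arr.max() > num_cats - 1:
        raise ValueError(f"values must lie in {{0, ..., {num_cats - 1}}}")
    return arr.astype(np.int64)


data = np.array([[0, 1, 2], [2, 0, 1], [1, np.nan, 0]])
complete = data[~np.isnan(data).any(axis=1)]   # one option: drop incomplete rows
codes = validate_cltree_data(complete, num_cats=3)
print(codes.shape)  # (2, 3)
```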

Parameters:
  • data (Tensor) – Training data of shape (N,F) with integer values in {0, ..., K-1}.

  • leaf (Literal['bernoulli', 'categorical']) – Discrete leaf family.

  • latent_dim (int) – Latent dimension d.

  • num_points_train (int) – Number of RQMC integration points during training.

  • num_points_eval (int | None) – Number of integration points for evaluation/early stopping. Defaults to num_points_train if None.

  • num_epochs (int) – Number of training epochs.

  • batch_size (int) – Mini-batch size.

  • lr (float) – Learning rate for Adam.

  • seed (int) – Random seed.

  • device (device | None) – Optional device for training.

  • dtype (dtype | None) – Optional dtype for training computations.

  • num_cats (int | None) – Number of categories K for categorical leaves. Ignored for Bernoulli leaves (K=2).

  • val_data (Tensor | None) – Optional validation data for early stopping.

  • patience (int) – Early stopping patience in epochs.

  • lo (LatentOptimizationConfig | None) – Latent optimization configuration. If None, LO is disabled.

  • alpha (float) – CLTree pseudocount used at compile time.

Return type:

JointLogLikelihood

Returns:

A compiled SPFlow module representing the trained model, wrapped so that log_likelihood returns a single feature (joint score).
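For reference, a standard Chow–Liu construction estimates the mutual information between every pair of features and keeps a maximum spanning tree over those scores. In this sketch alpha is added to every cell of each pairwise count table before normalizing; that is one common way to use a pseudocount and an assumption about what the alpha argument does. Rooting the tree at feature 0 is an arbitrary choice, and the helper names pairwise_mutual_information and chow_liu_parents are hypothetical.

```python
import numpy as np


def pairwise_mutual_information(data: np.ndarray, num_cats: int, alpha: float = 0.01) -> np.ndarray:
    """Mutual information of every feature pair from integer-coded data.

    `alpha` (> 0) is added to every cell of each K x K count table before
    normalizing, which also keeps all logarithms finite.
    """
    num_features = data.shape[1]
    mi = np.zeros((num_features, num_features))
    for i in range(num_features):
        for j in range(i + 1, num_features):
            joint = np.full((num_cats, num_cats), alpha)
            np.add.at(joint, (data[:, i], data[:, j]), 1.0)  # smoothed co-occurrence counts
            joint /= joint.sum()
            p_i = joint.sum(axis=1, keepdims=True)
            p_j = joint.sum(axis=0, keepdims=True)
            mi[i, j] = mi[j, i] = np.sum(joint * (np.log(joint) - np.log(p_i) - np.log(p_j)))
    return mi


def chow_liu_parents(mi: np.ndarray) -> np.ndarray:
    """Maximum spanning tree over `mi` via Prim's algorithm, rooted at feature 0.

    Returns parents[f] for each feature, with -1 for the root.
    """
    num_features = mi.shape[0]
    parents = np.full(num_features, -1)
    in_tree = np.zeros(num_features, dtype=bool)
    in_tree[0] = True
    best = mi[0].copy()                        # strongest link from each node into the tree
    best_from = np.zeros(num_features, dtype=int)
    for _ in range(num_features - 1):
        k = int(np.argmax(np.where(in_tree, -np.inf, best)))
        in_tree[k] = True
        parents[k] = best_from[k]
        closer = (mi[k] > best) & ~in_tree     # nodes more strongly linked to the new node
        best[closer] = mi[k][closer]
        best_from[closer] = k
    return parents


# Synthetic binary data: x1 is a noisy copy of x0, x2 a noisy copy of x1, x3 is independent.
rng = np.random.default_rng(0)
n = 2000
x0 = rng.integers(0, 2, n)
x1 = x0 ^ (rng.random(n) < 0.1)
x2 = x1 ^ (rng.random(n) < 0.1)
x3 = rng.integers(0, 2, n)
data = np.stack([x0, x1, x2, x3], axis=1)
print(chow_liu_parents(pairwise_mutual_information(data, num_cats=2, alpha=0.01)))
```

On this synthetic data the recovered tree is the chain 0 → 1 → 2, with feature 3 attached wherever its near-zero mutual information happens to be largest.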