Einsum Modules¶
Efficient sum-product operations using Einstein summation notation, as described in the Einsum Networks paper (Peharz et al., 2020). These layers fuse product and sum operations into single, efficient einsum operations.
Einet¶
High-level architecture for building Einsum Networks, a scalable deep probabilistic model using EinsumLayer or LinsumLayer for efficient batched computations.
Key parameters:
- num_classes: Number of root sum nodes (for classification)
- num_sums: Number of sum nodes per intermediate layer
- num_leaves: Number of leaf distribution components
- depth: Number of einsum layers (determines feature grouping: 2^depth features)
- num_repetitions: Number of parallel circuit repetitions
- layer_type: "einsum" (cross-product) or "linsum" (linear combination)
- structure: "top-down" or "bottom-up" construction mode
Reference: Peharz, R., et al. (2020). “Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits.” ICML 2020.
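A minimal construction sketch is shown below. The leaf module used here (a hypothetical Normal with a scope argument) and its import path are placeholders, not part of this reference; substitute whichever LeafModule implementations your spflow installation provides.

```python
from spflow.modules.einsum import Einet
# from spflow.modules.leaf import Normal   # hypothetical import path; adjust to your installation

leaf = Normal(scope=list(range(8)), out_channels=10)  # placeholder leaf over 8 features
model = Einet(
    leaf_modules=[leaf],
    num_classes=1,        # single root sum node -> density estimation
    num_sums=10,          # sum nodes per intermediate layer
    num_leaves=10,        # leaf distribution components
    depth=3,              # features grouped in blocks of 2**3 = 8
    num_repetitions=5,    # parallel circuit repetitions
    layer_type="linsum",  # linear channel combination (fewer parameters)
    structure="top-down",
)
```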
- class spflow.modules.einsum.Einet(leaf_modules, num_classes=1, num_sums=10, num_leaves=10, depth=1, num_repetitions=5, layer_type='linsum', structure='top-down')[source]¶
Bases: Module, Classifier
Einsum Network (Einet) for scalable deep probabilistic modeling.
Einet uses efficient einsum-based layers (EinsumLayer or LinsumLayer) to combine product and sum operations, enabling faster training and inference compared to traditional RAT-SPNs.
- leaf_modules¶
Leaf distribution modules.
- Type:
list[LeafModule]
- Reference:
Peharz, R., et al. (2020). “Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits.” ICML 2020.
- __init__(leaf_modules, num_classes=1, num_sums=10, num_leaves=10, depth=1, num_repetitions=5, layer_type='linsum', structure='top-down')[source]¶
Initialize Einet with specified architecture parameters.
- Parameters:
leaf_modules (list[LeafModule]) – Leaf distribution modules forming the base layer.
num_classes (int) – Number of root sum nodes (classes). Defaults to 1.
num_sums (int) – Number of sum nodes per intermediate layer. Defaults to 10.
num_leaves (int) – Number of leaf distribution components. Defaults to 10.
depth (int) – Number of einsum layers. Defaults to 1.
num_repetitions (int) – Number of parallel circuit repetitions. Defaults to 5.
layer_type (Literal['einsum', 'linsum']) – Type of intermediate layer ("einsum" or "linsum"). Defaults to "linsum".
structure (Literal['top-down', 'bottom-up']) – Structure building mode ("top-down" or "bottom-up"). Defaults to "top-down".
- Raises:
InvalidParameterError – If architectural parameters are invalid.
- log_posterior(data, cache=None)[source]¶
Compute log-posterior probabilities for multi-class models.
- maximum_likelihood_estimation(data, weights=None, cache=None)[source]¶
Update parameters via maximum likelihood estimation.
- sample(num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None)[source]¶
Generate samples from the Einet.
- Parameters:
num_samples (int | None) – Number of samples to generate.
data (Tensor | None) – Optional data tensor with evidence (NaN for missing).
is_mpe (bool) – Whether to perform MPE instead of sampling.
cache (Cache | None) – Optional cache with log-likelihoods for conditional sampling.
sampling_ctx (SamplingContext | None) – Sampling context with channel indices.
- Return type:
Tensor
- Returns:
Sampled tensor.
- Raises:
NotImplementedError – If structure is “bottom-up” (not yet supported).
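A usage sketch of the documented sampling interface, assuming an Einet instance named model (for example the one from the construction sketch above) and a (batch, features) layout for data, which is an assumption here; NaN entries mark missing values, following the layer-level sample() documentation.

```python
import torch

samples = model.sample(num_samples=100)            # unconditional samples

evidence = torch.full((1, 8), float("nan"))        # 8 features, all unobserved ...
evidence[0, 0] = 1.5                               # ... except feature 0
completion = model.sample(data=evidence, is_mpe=True)  # most-probable completion of the evidence
```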
EinsumLayer¶
Combines product and sum operations using a cross-product over input channels. Takes pairs of adjacent features as left/right children, computes their cross-product over channels (I × J combinations), and sums with learned weights using the LogEinsumExp trick for numerical stability.
Key characteristics:
- Weight shape: (features, out_channels, repetitions, left_channels, right_channels)
- Computes cross-product: I × J input channel combinations
- Uses LogEinsumExp for numerical stability in log-space
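The self-contained PyTorch sketch below illustrates the operation described above; it is not the library implementation, and the tensor layout (batch first, with channels and repetitions trailing) and the example sizes are assumptions made only for illustration.

```python
import torch

N, D, R = 4, 8, 5          # batch size, features, repetitions
I, J, O = 3, 4, 10         # left channels, right channels, output channels

left = torch.randn(N, D, I, R)     # left-child log-likelihoods
right = torch.randn(N, D, J, R)    # right-child log-likelihoods
logits = torch.randn(D, O, R, I, J)
weights = torch.softmax(logits.view(D, O, R, -1), dim=-1).view(D, O, R, I, J)

# LogEinsumExp trick: subtract per-feature maxima so exp() cannot overflow,
# compute the weighted cross-product sum in linear space, then return to log-space.
left_max = left.max(dim=2, keepdim=True).values    # (N, D, 1, R)
right_max = right.max(dim=2, keepdim=True).values
left_exp = (left - left_max).exp()
right_exp = (right - right_max).exp()

# Cross-product over channel pairs (i, j) with learned weights -> (N, D, O, R)
prob = torch.einsum("ndir,ndjr,dorij->ndor", left_exp, right_exp, weights)
log_out = prob.log() + left_max + right_max        # maxima broadcast over the O dimension
print(log_out.shape)                               # torch.Size([4, 8, 10, 5])
```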
- class spflow.modules.einsum.EinsumLayer(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]¶
Bases: Module
EinsumLayer combining product and sum operations efficiently.
Implements sum(product(x)) using einsum for circuits with arbitrary tree structure. Takes pairs of adjacent features as left/right children, computes their cross-product over channels, and sums with learned weights.
The LogEinsumExp trick is used for numerical stability in log-space.
- logits¶
Unnormalized log-weights for gradient optimization.
- Type:
Parameter
- unraveled_channel_indices¶
Mapping from flat to (i,j) channel pairs.
- Type:
Tensor
- __init__(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]¶
Initialize EinsumLayer.
- Parameters:
inputs (Module | list[Module]) – Either a single module (features will be split into pairs) or a list of exactly two modules (left and right children).
out_channels (int) – Number of output sum nodes per feature.
num_repetitions (int | None) – Number of repetitions. If None, inferred from inputs.
weights (Tensor | None) – Optional initial weights tensor. If provided, must have shape (out_features, out_channels, num_repetitions, left_channels, right_channels).
split_mode (SplitMode | None) – Optional split configuration for single-input mode. Use SplitMode.consecutive() or SplitMode.interleaved(). Defaults to SplitMode.consecutive(num_splits=2) if not specified.
- Raises:
ValueError – If the inputs are invalid, out_channels < 1, or the weight shape does not match.
- expectation_maximization(data, bias_correction=True, cache=None)[source]¶
Perform EM step to update weights.
- marginalize(marg_rvs, prune=True, cache=None)[source]¶
Marginalize out specified random variables.
- Parameters:
marg_rvs (list[int]) – Random variable indices to marginalize.
prune (bool) – Whether to prune unnecessary modules.
cache (Cache | None) – Cache for memoization.
- Return type:
Optional['EinsumLayer' | Module]
- Returns:
Marginalized module or None if fully marginalized.
- maximum_likelihood_estimation(data, weights=None, bias_correction=True, nan_strategy='ignore', cache=None)[source]¶
MLE step (equivalent to EM for sum nodes).
- Return type:
- sample(num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None)[source]¶
Sample from the EinsumLayer.
- Parameters:
data (Tensor | None) – Optional data tensor with evidence (NaN for missing).
is_mpe (bool) – Whether to perform MPE instead of sampling.
cache (Cache | None) – Optional cache with log-likelihoods for conditional sampling.
sampling_ctx (SamplingContext | None) – Sampling context with channel indices.
- Return type:
Tensor
- Returns:
Sampled data tensor.
LinsumLayer¶
Linear sum-product layer with a simpler linear combination over channels. Unlike EinsumLayer which computes a cross-product (I × J), LinsumLayer pairs left/right features, adds them (product in log-space), then sums over input channels with learned weights.
Key characteristics:
- Weight shape: (features, out_channels, repetitions, in_channels)
- Linear combination: requires left and right inputs to have matching channel counts
- Fewer parameters than EinsumLayer: O(C) vs O(C²)
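For contrast, a sketch of the linear channel combination under the same assumed tensor layout as the EinsumLayer sketch above (again not the library implementation):

```python
import torch

N, D, R, C, O = 4, 8, 5, 3, 10      # batch, features, repetitions, in/out channels

left = torch.randn(N, D, C, R)       # left-child log-likelihoods
right = torch.randn(N, D, C, R)      # right-child log-likelihoods (same C as left)
logits = torch.randn(D, O, R, C)
log_weights = torch.log_softmax(logits, dim=-1)    # normalized over the C channels

prod = left + right                                # product of children in log-space
prod = prod.permute(0, 1, 3, 2).unsqueeze(2)       # (N, D, 1, R, C)
# Weighted logsumexp over the shared channel dimension -> (N, D, O, R)
log_out = torch.logsumexp(prod + log_weights.unsqueeze(0), dim=-1)
print(log_out.shape)                               # torch.Size([4, 8, 10, 5])
```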
- class spflow.modules.einsum.LinsumLayer(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]¶
Bases: Module
LinsumLayer combining product and sum operations with linear channel combination.
Unlike EinsumLayer which computes cross-product over channels (I × J combinations), LinsumLayer computes a linear combination: pairs left/right features, adds them (product in log-space), then sums over input channels with learned weights.
This results in fewer parameters: weight_shape = (D_out, O, R, C) vs EinsumLayer’s (D_out, O, R, I, J).
- logits¶
Unnormalized log-weights for gradient optimization.
- Type:
Parameter
- __init__(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]¶
Initialize LinsumLayer.
- Parameters:
inputs (Module | list[Module]) – Either a single module (features will be split into pairs) or a list of exactly two modules (left and right children). Unlike EinsumLayer, both inputs must have the same number of channels.
out_channels (int) – Number of output sum nodes per feature.
num_repetitions (int | None) – Number of repetitions. If None, inferred from inputs.
weights (Tensor | None) – Optional initial weights tensor. If provided, must have shape (out_features, out_channels, num_repetitions, in_channels).
split_mode (SplitMode | None) – Optional split configuration for single-input mode. Use SplitMode.consecutive() or SplitMode.interleaved(). Defaults to SplitMode.consecutive(num_splits=2) if not specified.
- Raises:
ValueError – If the inputs are invalid, out_channels < 1, or the weight shape does not match.
- expectation_maximization(data, bias_correction=True, cache=None)[source]¶
Perform EM step to update weights.
- log_likelihood(data, cache=None)[source]¶
Compute log-likelihood using linear sum over channels.
Unlike EinsumLayer which computes cross-product (I × J), this computes a linear combination: add left+right (product), then logsumexp over channels.
- marginalize(marg_rvs, prune=True, cache=None)[source]¶
Marginalize out specified random variables.
- Parameters:
marg_rvs (list[int]) – Random variable indices to marginalize.
prune (bool) – Whether to prune unnecessary modules.
cache (Cache | None) – Cache for memoization.
- Return type:
Optional['LinsumLayer' | Module]
- Returns:
Marginalized module or None if fully marginalized.
- maximum_likelihood_estimation(data, weights=None, bias_correction=True, nan_strategy='ignore', cache=None)[source]¶
MLE step (equivalent to EM for sum nodes).
- Return type:
- sample(num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None)[source]¶
Sample from the LinsumLayer.
- Parameters:
data (Tensor | None) – Optional data tensor with evidence (NaN for missing).
is_mpe (bool) – Whether to perform MPE instead of sampling.
cache (Cache | None) – Optional cache with log-likelihoods for conditional sampling.
sampling_ctx (SamplingContext | None) – Sampling context with channel indices.
- Return type:
Tensor
- Returns:
Sampled data tensor.
Comparison¶
| Layer | Weight Shape | Channel Operation | Parameter Count |
|---|---|---|---|
| EinsumLayer | (D, O, R, I, J) | Cross-product I × J | O(D · O · R · I · J) |
| LinsumLayer | (D, O, R, C) | Linear sum | O(D · O · R · C) |
Use EinsumLayer when you need maximum expressiveness with different left/right channel counts. Use LinsumLayer when you want fewer parameters and have matching channel counts.
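To make the comparison concrete, a quick weight count for arbitrary example dimensions (with I = J = C for the EinsumLayer case):

```python
D, O, R, C = 8, 10, 5, 10
einsum_params = D * O * R * C * C    # weight shape (D, O, R, I, J)
linsum_params = D * O * R * C        # weight shape (D, O, R, C)
print(einsum_params, linsum_params)  # 40000 vs 4000 -> a factor of C fewer weights
```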