Einsum Networks (Einet)¶
Einsum Networks (Einets) are a scalable class of probabilistic circuits that use Einstein summation notation (einsum) to implement efficient sum-product operations in parallel.
Reference¶
Einets were introduced in: Peharz, R., et al. (2020). “Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits.” ICML 2020.
Overview¶
Einet provides a scalable architecture for Sum-Product Networks using EinsumLayer or LinsumLayer for efficient batched computations. These layers combine product and sum operations into single efficient einsum operations.
Key Characteristics:¶
Efficient batched computations: Leverage PyTorch’s optimized einsum implementation.
Scalable deep architecture: Supports deep stacks of einsum/linsum layers.
Fast inference and sampling: Optimized for high-throughput probabilistic modeling.
Implementation¶
The Einet implementation in SPFlow provides a high-level spflow.zoo.einet.Einet module.
- class spflow.zoo.einet.Einet(leaf_modules, num_classes=1, num_sums=10, num_leaves=10, depth=1, num_repetitions=5, layer_type='linsum', structure='top-down')[source]¶
Bases: Module, Classifier
Einsum Network (Einet) for scalable deep probabilistic modeling.
Einet uses efficient einsum-based layers (EinsumLayer or LinsumLayer) to combine product and sum operations, enabling faster training and inference compared to traditional RAT-SPNs.
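The core idea, fusing a product of child distributions with a weighted sum into one einsum call, can be sketched in NumPy. Shapes and names here are illustrative only (not the SPFlow internals): `left`/`right` are log-probabilities of paired children, and each output is a weighted mixture over the cross-product of their channels, computed stably in log-space.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, I, J, O = 4, 3, 2, 2, 5  # batch, features, left/right channels, out sums

left = np.log(rng.random((N, D, I)))   # log-probs of left children
right = np.log(rng.random((N, D, J)))  # log-probs of right children

# normalized mixture weights over the (I, J) cross-product of child channels
w = rng.random((D, O, I, J))
w /= w.sum(axis=(-2, -1), keepdims=True)

# log-einsum-exp: shift by per-feature maxima for numerical stability
lm = left.max(axis=-1, keepdims=True)
rm = right.max(axis=-1, keepdims=True)
prod = np.exp(left - lm)[:, :, :, None] * np.exp(right - rm)[:, :, None, :]  # (N, D, I, J)
out = np.log(np.einsum("ndij,doij->ndo", prod, w)) + lm + rm  # (N, D, O)
```

Product and sum thus collapse into a single contraction, which is what makes the layer fast on batched GPU workloads.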
- leaf_modules¶
Leaf distribution modules.
- Type:
list[LeafModule]
- Reference:
Peharz, R., et al. (2020). “Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits.” ICML 2020.
- __init__(leaf_modules, num_classes=1, num_sums=10, num_leaves=10, depth=1, num_repetitions=5, layer_type='linsum', structure='top-down')[source]¶
Initialize Einet with specified architecture parameters.
- Parameters:
leaf_modules (list[LeafModule]) – Leaf distribution modules forming the base layer.
num_classes (int) – Number of root sum nodes (classes). Defaults to 1.
num_sums (int) – Number of sum nodes per intermediate layer. Defaults to 10.
num_leaves (int) – Number of leaf distribution components. Defaults to 10.
depth (int) – Number of einsum layers. Defaults to 1.
num_repetitions (int) – Number of parallel circuit repetitions. Defaults to 5.
layer_type (Literal['einsum', 'linsum']) – Type of intermediate layer (“einsum” or “linsum”). Defaults to “linsum”.
structure (Literal['top-down', 'bottom-up']) – Structure building mode (“top-down” or “bottom-up”). Defaults to “top-down”.
- Raises:
InvalidParameterError – If architectural parameters are invalid.
- log_posterior(data, cache=None)[source]¶
Compute log-posterior probabilities for multi-class models.
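For intuition, a multi-class posterior can be recovered from per-class root log-likelihoods by normalizing in log-space. This is a sketch assuming a uniform class prior, not the SPFlow implementation:

```python
import numpy as np

# per-class root log-likelihoods, shape (batch, num_classes) -- illustrative values
log_lik = np.log(np.array([[0.2, 0.5, 0.3]]))

# log p(c | x) = log p(x | c) - logsumexp_c log p(x | c) under a uniform prior,
# computed with the max-shift trick for numerical stability
m = log_lik.max(axis=1, keepdims=True)
log_post = log_lik - (m + np.log(np.exp(log_lik - m).sum(axis=1, keepdims=True)))
```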
- sample(num_samples=None, data=None, is_mpe=False, cache=None)[source]¶
Generate samples from the Einet.
- Returns:
Sampled tensor.
- Raises:
NotImplementedError – If structure is “bottom-up” (not yet supported).
Layers¶
- class spflow.modules.einsum.EinsumLayer(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]¶
Bases: Module
EinsumLayer combining product and sum operations efficiently.
Implements sum(product(x)) using einsum for circuits with arbitrary tree structure. Takes pairs of adjacent features as left/right children, computes their cross-product over channels, and sums with learned weights.
The LogEinsumExp trick is used for numerical stability in log-space.
- logits¶
Unnormalized log-weights for gradient optimization.
- Type:
Parameter
- unraveled_channel_indices¶
Mapping from flat to (i,j) channel pairs.
- Type:
Tensor
- __init__(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]¶
Initialize EinsumLayer.
- Parameters:
inputs (Module | list[Module]) – Either a single module (features will be split into pairs) or a list of exactly two modules (left and right children).
out_channels (int) – Number of output sum nodes per feature.
num_repetitions (int | None) – Number of repetitions. If None, inferred from inputs.
weights (Tensor | None) – Optional initial weights tensor. If provided, must have shape (out_features, out_channels, num_repetitions, left_channels, right_channels).
split_mode (SplitMode | None) – Optional split configuration for single-input mode. Use SplitMode.consecutive() or SplitMode.interleaved(). Defaults to SplitMode.consecutive(num_splits=2) if not specified.
- Raises:
ValueError – If inputs invalid, out_channels < 1, or weight shape mismatch.
- marginalize(marg_rvs, prune=True, cache=None)[source]¶
Marginalize out specified random variables.
- Parameters:
marg_rvs (list[int]) – Random variable indices to marginalize.
prune (bool) – Whether to prune unnecessary modules.
cache (Cache | None) – Cache for memoization.
- Return type:
EinsumLayer | Module | None
- Returns:
Marginalized module or None if fully marginalized.
- class spflow.modules.einsum.LinsumLayer(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]¶
Bases: Module
LinsumLayer combining product and sum operations with linear channel combination.
Unlike EinsumLayer which computes cross-product over channels (I × J combinations), LinsumLayer computes a linear combination: pairs left/right features, adds them (product in log-space), then sums over input channels with learned weights.
This results in fewer parameters: weight_shape = (D_out, O, R, C) vs EinsumLayer’s (D_out, O, R, I, J).
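The saving can be made concrete with the weight shapes above (sizes here are illustrative, not the class defaults):

```python
# parameter counts for one layer, from the weight shapes above
D_out, O, R = 8, 10, 5  # output features, out channels (sum nodes), repetitions
I = J = C = 10          # child channel counts (I, J for Einsum; C for Linsum)

einsum_params = D_out * O * R * I * J  # cross-product weights (D_out, O, R, I, J)
linsum_params = D_out * O * R * C      # linear weights (D_out, O, R, C)
```

With equal channel counts the Linsum weight tensor is smaller by a factor of the channel count, here 10x (4,000 vs. 40,000 parameters).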
- logits¶
Unnormalized log-weights for gradient optimization.
- Type:
Parameter
- __init__(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]¶
Initialize LinsumLayer.
- Parameters:
inputs (Module | list[Module]) – Either a single module (features will be split into pairs) or a list of exactly two modules (left and right children). Unlike EinsumLayer, both inputs must have the same number of channels.
out_channels (int) – Number of output sum nodes per feature.
num_repetitions (int | None) – Number of repetitions. If None, inferred from inputs.
weights (Tensor | None) – Optional initial weights tensor. If provided, must have shape (out_features, out_channels, num_repetitions, in_channels).
split_mode (SplitMode | None) – Optional split configuration for single-input mode. Use SplitMode.consecutive() or SplitMode.interleaved(). Defaults to SplitMode.consecutive(num_splits=2) if not specified.
- Raises:
ValueError – If inputs invalid, out_channels < 1, or weight shape mismatch.
- log_likelihood(data, cache=None)[source]¶
Compute log-likelihood using linear sum over channels.
Unlike EinsumLayer which computes cross-product (I × J), this computes a linear combination: add left+right (product), then logsumexp over channels.
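That linear combination can be sketched in NumPy (shapes and names are illustrative, not the SPFlow implementation): paired children are added elementwise (a product in log-space), then a stable logsumexp mixes the shared input channels with normalized weights.

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, C, O = 4, 3, 2, 5  # batch, output features, shared channels, out sums

left = np.log(rng.random((N, D, C)))   # paired left children, log-space
right = np.log(rng.random((N, D, C)))  # paired right children, log-space

w = rng.random((D, O, C))
logw = np.log(w / w.sum(axis=-1, keepdims=True))  # normalized log-weights

prod = left + right                   # product of each pair = sum in log-space, (N, D, C)
z = logw[None] + prod[:, :, None, :]  # weighted terms, (N, D, O, C)
m = z.max(axis=-1, keepdims=True)     # max-shift for a stable logsumexp
out = (m + np.log(np.exp(z - m).sum(axis=-1, keepdims=True))).squeeze(-1)  # (N, D, O)
```

Note the single channel axis C: there is no (I, J) cross-product, which is exactly where the parameter saving over EinsumLayer comes from.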
- marginalize(marg_rvs, prune=True, cache=None)[source]¶
Marginalize out specified random variables.
- Parameters:
marg_rvs (list[int]) – Random variable indices to marginalize.
prune (bool) – Whether to prune unnecessary modules.
cache (Cache | None) – Cache for memoization.
- Return type:
LinsumLayer | Module | None
- Returns:
Marginalized module or None if fully marginalized.