Einsum Modules

Efficient sum-product operations using Einstein summation notation, as described in the Einsum Networks paper (Peharz et al., 2020). These layers combine product and sum operations into single efficient einsum operations.

Einet

High-level architecture for building Einsum Networks, a scalable deep probabilistic model that uses EinsumLayer or LinsumLayer for efficient batched computations.

Key parameters:

  • num_classes: Number of root sum nodes (for classification)

  • num_sums: Number of sum nodes per intermediate layer

  • num_leaves: Number of leaf distribution components

  • depth: Number of einsum layers (determines feature grouping: 2^depth features)

  • num_repetitions: Number of parallel circuit repetitions

  • layer_type: "einsum" (cross-product) or "linsum" (linear combination)

  • structure: "top-down" or "bottom-up" construction mode

Reference: Peharz, R., et al. (2020). “Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits.” ICML 2020.
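
The following minimal sketch ties these parameters together for a small 3-class model. It is a construction sketch only: `leaf_modules` is assumed to already hold the leaf distribution modules (their construction is documented elsewhere and omitted here).

```python
# Construction sketch, assuming `leaf_modules` already contains the leaf
# distribution modules (list[LeafModule]) covering all input features.
from spflow.modules.einsum import Einet

model = Einet(
    leaf_modules=leaf_modules,  # leaf distribution modules for the input features
    num_classes=3,              # three root sum nodes -> a 3-class classifier
    num_sums=10,                # sum nodes per intermediate layer
    num_leaves=10,              # leaf distribution components
    depth=3,                    # 3 einsum layers -> feature groups of 2**3 = 8
    num_repetitions=5,          # parallel circuit repetitions
    layer_type="linsum",        # "einsum" (cross-product) or "linsum" (linear combination)
    structure="top-down",       # construction mode
)
```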

class spflow.modules.einsum.Einet(leaf_modules, num_classes=1, num_sums=10, num_leaves=10, depth=1, num_repetitions=5, layer_type='linsum', structure='top-down')[source]

Bases: Module, Classifier

Einsum Network (Einet) for scalable deep probabilistic modeling.

Einet uses efficient einsum-based layers (EinsumLayer or LinsumLayer) to combine product and sum operations, enabling faster training and inference compared to traditional RAT-SPNs.

leaf_modules

Leaf distribution modules.

Type:

list[LeafModule]

num_classes

Number of output classes (root sum nodes).

Type:

int

num_sums

Number of sum nodes per intermediate layer.

Type:

int

num_leaves

Number of leaf distribution components.

Type:

int

depth

Number of einsum layers.

Type:

int

num_repetitions

Number of parallel circuit repetitions.

Type:

int

layer_type

Type of intermediate layer (“einsum” or “linsum”).

Type:

str

structure

Structure building mode (“top-down” or “bottom-up”).

Type:

str

Reference:

Peharz, R., et al. (2020). “Einsum Networks: Fast and Scalable Learning of Tractable Probabilistic Circuits.” ICML 2020.

__init__(leaf_modules, num_classes=1, num_sums=10, num_leaves=10, depth=1, num_repetitions=5, layer_type='linsum', structure='top-down')[source]

Initialize Einet with specified architecture parameters.

Parameters:
  • leaf_modules (list[LeafModule]) – Leaf distribution modules forming the base layer.

  • num_classes (int) – Number of root sum nodes (classes). Defaults to 1.

  • num_sums (int) – Number of sum nodes per intermediate layer. Defaults to 10.

  • num_leaves (int) – Number of leaf distribution components. Defaults to 10.

  • depth (int) – Number of einsum layers. Defaults to 1.

  • num_repetitions (int) – Number of parallel circuit repetitions. Defaults to 5.

  • layer_type (Literal['einsum', 'linsum']) – Type of intermediate layer (“einsum” or “linsum”). Defaults to “linsum”.

  • structure (Literal['top-down', 'bottom-up']) – Structure building mode (“top-down” or “bottom-up”). Defaults to “top-down”.

Raises:

InvalidParameterError – If architectural parameters are invalid.

expectation_maximization(data, cache=None)[source]

Perform expectation-maximization step.

Parameters:
  • data (Tensor) – Input data tensor.

  • cache (Cache | None) – Optional cache with log-likelihoods.

Return type:

None

log_likelihood(data, cache=None)[source]

Compute log-likelihood for input data.

Parameters:
  • data (Tensor) – Input data tensor of shape (batch_size, num_features).

  • cache (Cache | None) – Optional cache for intermediate results.

Return type:

Tensor

Returns:

Log-likelihood tensor of shape (batch_size, 1, num_classes, 1).
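
A short query sketch, continuing the construction example above; the random tensor is only a stand-in for real data of shape (batch_size, num_features).

```python
import torch

x = torch.randn(32, 8)        # stand-in data: (batch_size, num_features)
ll = model.log_likelihood(x)  # documented shape: (batch_size, 1, num_classes, 1)
print(ll.shape)               # e.g. torch.Size([32, 1, 3, 1]) for the 3-class model above
```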

log_posterior(data, cache=None)[source]

Compute log-posterior probabilities for multi-class models.

Parameters:
  • data (Tensor) – Input data tensor.

  • cache (Cache | None) – Optional cache for intermediate results.

Return type:

Tensor

Returns:

Log-posterior probabilities of shape (batch_size, num_classes).

Raises:

UnsupportedOperationError – If model has only one class.

marginalize(marg_rvs, prune=True, cache=None)[source]

Marginalize out specified random variables.

Parameters:
  • marg_rvs (list[int]) – Random variable indices to marginalize.

  • prune (bool) – Whether to prune redundant modules.

  • cache (Cache | None) – Optional cache.

Return type:

Module | None

Returns:

Marginalized module or None if fully marginalized.
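
A marginalization sketch following the signature above; the indices are chosen purely for illustration.

```python
# Integrate out the random variables at indices 0 and 2.
marg_model = model.marginalize(marg_rvs=[0, 2], prune=True)
# Result is the reduced module, or None if every random variable was marginalized.
```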

maximum_likelihood_estimation(data, weights=None, cache=None)[source]

Update parameters via maximum likelihood estimation.

Parameters:
  • data (Tensor) – Input data tensor.

  • weights (Tensor | None) – Optional sample weights.

  • cache (Cache | None) – Optional cache.

Return type:

None

predict_proba(data)[source]

Predict class probabilities.

Parameters:

data (Tensor) – Input data tensor.

Return type:

Tensor

Returns:

Class probabilities of shape (batch_size, num_classes).
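
A classification sketch for a model with num_classes > 1, continuing the example above.

```python
probs = model.predict_proba(x)     # (batch_size, num_classes) class probabilities
preds = probs.argmax(dim=-1)       # hard class predictions
log_post = model.log_posterior(x)  # (batch_size, num_classes) log-posteriors
```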

sample(num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None)[source]

Generate samples from the Einet.

Parameters:
  • num_samples (int | None) – Number of samples to generate.

  • data (Tensor | None) – Optional data tensor with NaN values to impute.

  • is_mpe (bool) – Whether to perform MPE (most probable explanation).

  • cache (Cache | None) – Optional cache for intermediate results.

  • sampling_ctx (SamplingContext | None) – Optional sampling context.

Return type:

Tensor

Returns:

Sampled tensor.

Raises:

NotImplementedError – If structure is “bottom-up” (not yet supported).
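
A sampling sketch following the signature above: unconditional sampling on the one hand, NaN imputation with MPE on the other (the feature index is illustrative).

```python
samples = model.sample(num_samples=16)                 # draw 16 new samples

x_missing = x.clone()
x_missing[:, 3] = float("nan")                         # mark feature 3 as unobserved
completed = model.sample(data=x_missing, is_mpe=True)  # most probable completion of the NaNs
```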

property feature_to_scope: ndarray

Mapping from output features to their scopes.

property n_out: int

Number of output nodes.

property scopes_out: list[Scope]

Output scopes.

EinsumLayer

Combines product and sum operations using a cross-product over input channels. Takes pairs of adjacent features as left/right children, computes their cross-product over channels (I × J combinations), and sums with learned weights using the LogEinsumExp trick for numerical stability.

Key characteristics:

  • Weight shape: (features, out_channels, repetitions, left_channels, right_channels)

  • Computes cross-product: I × J input channel combinations

  • Uses LogEinsumExp for numerical stability in log-space
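
A construction sketch for a standalone EinsumLayer, following the signature below; `left` and `right` stand in for two existing child modules over adjacent feature blocks, whose channel counts I and J may differ.

```python
from spflow.modules.einsum import EinsumLayer

layer = EinsumLayer(
    inputs=[left, right],  # exactly two child modules (left/right children)
    out_channels=8,        # sum nodes per output feature
    num_repetitions=5,     # may be omitted and inferred from the inputs
)
# layer.weights: (out_features, 8, 5, I, J) -- one normalized weight per (i, j) channel pair
```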

class spflow.modules.einsum.EinsumLayer(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]

Bases: Module

EinsumLayer combining product and sum operations efficiently.

Implements sum(product(x)) using einsum for circuits with arbitrary tree structure. Takes pairs of adjacent features as left/right children, computes their cross-product over channels, and sums with learned weights.

The LogEinsumExp trick is used for numerical stability in log-space.

logits

Unnormalized log-weights for gradient optimization.

Type:

Parameter

unraveled_channel_indices

Mapping from flat to (i,j) channel pairs.

Type:

Tensor

__init__(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]

Initialize EinsumLayer.

Parameters:
  • inputs (Module | list[Module]) – Either a single module (features will be split into pairs) or a list of exactly two modules (left and right children).

  • out_channels (int) – Number of output sum nodes per feature.

  • num_repetitions (int | None) – Number of repetitions. If None, inferred from inputs.

  • weights (Tensor | None) – Optional initial weights tensor. If provided, must have shape (out_features, out_channels, num_repetitions, left_channels, right_channels).

  • split_mode (SplitMode | None) – Optional split configuration for single input mode. Use SplitMode.consecutive() or SplitMode.interleaved(). Defaults to SplitMode.consecutive(num_splits=2) if not specified.

Raises:

ValueError – If the inputs are invalid, out_channels < 1, or the weight shape does not match.

expectation_maximization(data, bias_correction=True, cache=None)[source]

Perform EM step to update weights.

Parameters:
  • data (Tensor) – Training data tensor.

  • bias_correction (bool) – Whether to apply bias correction.

  • cache (Cache | None) – Cache with log-likelihoods.

Return type:

None

log_likelihood(data, cache=None)[source]

Compute log-likelihood using LogEinsumExp trick.

Parameters:
  • data (Tensor) – Input data of shape (batch_size, num_features).

  • cache (Cache | None) – Optional cache for intermediate results.

Return type:

Tensor

Returns:

Log-likelihood tensor of shape (batch, out_features, out_channels, reps).

marginalize(marg_rvs, prune=True, cache=None)[source]

Marginalize out specified random variables.

Parameters:
  • marg_rvs (list[int]) – Random variable indices to marginalize.

  • prune (bool) – Whether to prune unnecessary modules.

  • cache (Cache | None) – Cache for memoization.

Return type:

EinsumLayer | Module | None

Returns:

Marginalized module or None if fully marginalized.

maximum_likelihood_estimation(data, weights=None, bias_correction=True, nan_strategy='ignore', cache=None)[source]

MLE step (equivalent to EM for sum nodes).

Return type:

None

sample(num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None)[source]

Sample from the EinsumLayer.

Parameters:
  • num_samples (int | None) – Number of samples to generate.

  • data (Tensor | None) – Optional data tensor with evidence (NaN for missing).

  • is_mpe (bool) – Whether to perform MPE instead of sampling.

  • cache (Cache | None) – Optional cache with log-likelihoods for conditional sampling.

  • sampling_ctx (SamplingContext | None) – Sampling context with channel indices.

Return type:

Tensor

Returns:

Sampled data tensor.

property feature_to_scope: ndarray

Mapping from output features to their scopes.

property log_weights: Tensor

Log of the normalized weights (weights sum to 1 over input channel pairs).

property weights: Tensor

Normalized weights (sum to 1 over input channel pairs).

LinsumLayer

Linear sum-product layer with a simpler linear combination over channels. Unlike EinsumLayer which computes a cross-product (I × J), LinsumLayer pairs left/right features, adds them (product in log-space), then sums over input channels with learned weights.

Key characteristics:

  • Weight shape: (features, out_channels, repetitions, in_channels)

  • Linear combination: requires left and right inputs to have matching channel counts

  • Fewer parameters than EinsumLayer: O(C) vs O(C²)
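
A construction sketch for a LinsumLayer in single-input mode, following the signature below; `child` is a placeholder for an existing module whose features are split into left/right pairs.

```python
from spflow.modules.einsum import LinsumLayer

layer = LinsumLayer(
    inputs=child,    # single module; its features are split into pairs
    out_channels=8,  # sum nodes per output feature
    # split_mode defaults to SplitMode.consecutive(num_splits=2);
    # pass SplitMode.interleaved() for interleaved pairing instead
)
# layer.weights: (out_features, 8, num_repetitions, in_channels) -- linear in the channel count
```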

class spflow.modules.einsum.LinsumLayer(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]

Bases: Module

LinsumLayer combining product and sum operations with linear channel combination.

Unlike EinsumLayer which computes cross-product over channels (I × J combinations), LinsumLayer computes a linear combination: pairs left/right features, adds them (product in log-space), then sums over input channels with learned weights.

This results in fewer parameters: weight_shape = (D_out, O, R, C) vs EinsumLayer’s (D_out, O, R, I, J).

logits

Unnormalized log-weights for gradient optimization.

Type:

Parameter

__init__(inputs, out_channels, num_repetitions=None, weights=None, split_mode=None)[source]

Initialize LinsumLayer.

Parameters:
  • inputs (Module | list[Module]) – Either a single module (features will be split into pairs) or a list of exactly two modules (left and right children). Unlike EinsumLayer, both inputs must have the same number of channels.

  • out_channels (int) – Number of output sum nodes per feature.

  • num_repetitions (int | None) – Number of repetitions. If None, inferred from inputs.

  • weights (Tensor | None) – Optional initial weights tensor. If provided, must have shape (out_features, out_channels, num_repetitions, in_channels).

  • split_mode (SplitMode | None) – Optional split configuration for single input mode. Use SplitMode.consecutive() or SplitMode.interleaved(). Defaults to SplitMode.consecutive(num_splits=2) if not specified.

Raises:

ValueError – If the inputs are invalid, out_channels < 1, or the weight shape does not match.

expectation_maximization(data, bias_correction=True, cache=None)[source]

Perform EM step to update weights.

Parameters:
  • data (Tensor) – Training data tensor.

  • bias_correction (bool) – Whether to apply bias correction.

  • cache (Cache | None) – Cache with log-likelihoods.

Return type:

None

log_likelihood(data, cache=None)[source]

Compute log-likelihood using linear sum over channels.

Unlike EinsumLayer which computes cross-product (I × J), this computes a linear combination: add left+right (product), then logsumexp over channels.

Parameters:
  • data (Tensor) – Input data of shape (batch_size, num_features).

  • cache (Cache | None) – Optional cache for intermediate results.

Return type:

Tensor

Returns:

Log-likelihood tensor of shape (batch, out_features, out_channels, reps).

marginalize(marg_rvs, prune=True, cache=None)[source]

Marginalize out specified random variables.

Parameters:
  • marg_rvs (list[int]) – Random variable indices to marginalize.

  • prune (bool) – Whether to prune unnecessary modules.

  • cache (Cache | None) – Cache for memoization.

Return type:

LinsumLayer | Module | None

Returns:

Marginalized module or None if fully marginalized.

maximum_likelihood_estimation(data, weights=None, bias_correction=True, nan_strategy='ignore', cache=None)[source]

MLE step (equivalent to EM for sum nodes).

Return type:

None

sample(num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None)[source]

Sample from the LinsumLayer.

Parameters:
  • num_samples (int | None) – Number of samples to generate.

  • data (Tensor | None) – Optional data tensor with evidence (NaN for missing).

  • is_mpe (bool) – Whether to perform MPE instead of sampling.

  • cache (Cache | None) – Optional cache with log-likelihoods for conditional sampling.

  • sampling_ctx (SamplingContext | None) – Sampling context with channel indices.

Return type:

Tensor

Returns:

Sampled data tensor.

property feature_to_scope: ndarray

Mapping from output features to their scopes.

property log_weights: Tensor

Log of the normalized weights (weights sum to 1 over input channels).

property weights: Tensor

Normalized weights (sum to 1 over input channels).

Comparison

Layer       | Weight Shape    | Channel Operation   | Parameter Count
EinsumLayer | (D, O, R, I, J) | Cross-product I × J | O(D · O · R · I · J)
LinsumLayer | (D, O, R, C)    | Linear sum          | O(D · O · R · C)

Use EinsumLayer when you need maximum expressiveness with different left/right channel counts. Use LinsumLayer when you want fewer parameters and have matching channel counts.
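
A back-of-envelope comparison of the two parameter counts, with illustrative sizes:

```python
D, O, R, C = 16, 10, 5, 10         # features, out_channels, repetitions, channels (I = J = C)
einsum_params = D * O * R * C * C  # (D, O, R, I, J) -> 80,000 weights
linsum_params = D * O * R * C      # (D, O, R, C)    ->  8,000 weights
```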