Source code for spflow.modules.einsum.einet

"""Einet (EinsumNetworks) module for efficient deep probabilistic models.

Einet provides a scalable architecture for Sum-Product Networks using
EinsumLayer or LinsumLayer for efficient batched computations.

Reference:
    Peharz, R., et al. (2020). "Einsum Networks: Fast and Scalable Learning of
    Tractable Probabilistic Circuits." ICML 2020.
"""

from __future__ import annotations

from typing import Literal, Optional

import numpy as np
import torch
from torch import nn

from spflow.exceptions import InvalidParameterError, UnsupportedOperationError
from spflow.interfaces.classifier import Classifier
from spflow.meta.data.scope import Scope
from spflow.modules.einsum.einsum_layer import EinsumLayer
from spflow.modules.einsum.linsum_layer import LinsumLayer
from spflow.modules.leaves.leaf import LeafModule
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.rat.factorize import Factorize
from spflow.modules.sums.repetition_mixing_layer import RepetitionMixingLayer
from spflow.modules.sums.sum import Sum
from spflow.utils.cache import Cache, cached
from spflow.utils.inference import log_posterior
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context


class Einet(Module, Classifier):
    """Einsum Network (Einet) for scalable deep probabilistic modeling.

    Einet uses efficient einsum-based layers (EinsumLayer or LinsumLayer) to
    combine product and sum operations, enabling faster training and inference
    compared to traditional RAT-SPNs.

    Attributes:
        leaf_modules (list[LeafModule]): Leaf distribution modules.
        num_classes (int): Number of output classes (root sum nodes).
        num_sums (int): Number of sum nodes per intermediate layer.
        num_leaves (int): Number of leaf distribution components.
        depth (int): Number of einsum layers.
        num_repetitions (int): Number of parallel circuit repetitions.
        layer_type (str): Type of intermediate layer ("einsum" or "linsum").
        structure (str): Structure building mode ("top-down" or "bottom-up").

    Reference:
        Peharz, R., et al. (2020). "Einsum Networks: Fast and Scalable Learning of
        Tractable Probabilistic Circuits." ICML 2020.
    """
    def __init__(
        self,
        leaf_modules: list[LeafModule],
        num_classes: int = 1,
        num_sums: int = 10,
        num_leaves: int = 10,
        depth: int = 1,
        num_repetitions: int = 5,
        layer_type: Literal["einsum", "linsum"] = "linsum",
        structure: Literal["top-down", "bottom-up"] = "top-down",
    ) -> None:
        """Initialize Einet with specified architecture parameters.

        Args:
            leaf_modules: Leaf distribution modules forming the base layer.
            num_classes: Number of root sum nodes (classes). Defaults to 1.
            num_sums: Number of sum nodes per intermediate layer. Defaults to 10.
            num_leaves: Number of leaf distribution components. Defaults to 10.
            depth: Number of einsum layers. Defaults to 1.
            num_repetitions: Number of parallel circuit repetitions. Defaults to 5.
            layer_type: Type of intermediate layer ("einsum" or "linsum").
                Defaults to "linsum".
            structure: Structure building mode ("top-down" or "bottom-up").
                Defaults to "top-down".

        Raises:
            InvalidParameterError: If architectural parameters are invalid.
        """
        super().__init__()

        # Validate parameters
        if num_classes < 1:
            raise InvalidParameterError(f"num_classes must be >= 1, got {num_classes}")
        if num_sums < 1:
            raise InvalidParameterError(f"num_sums must be >= 1, got {num_sums}")
        if num_leaves < 1:
            raise InvalidParameterError(f"num_leaves must be >= 1, got {num_leaves}")
        if depth < 0:
            raise InvalidParameterError(f"depth must be >= 0, got {depth}")
        if num_repetitions < 1:
            raise InvalidParameterError(f"num_repetitions must be >= 1, got {num_repetitions}")
        if layer_type not in ("einsum", "linsum"):
            raise InvalidParameterError(f"layer_type must be 'einsum' or 'linsum', got {layer_type}")
        if structure not in ("top-down", "bottom-up"):
            raise InvalidParameterError(f"structure must be 'top-down' or 'bottom-up', got {structure}")

        # Store configuration
        self.leaf_modules = nn.ModuleList(leaf_modules)
        self.num_classes = num_classes
        self.num_sums = num_sums
        self.num_leaves = num_leaves
        self.depth = depth
        self.num_repetitions = num_repetitions
        self.layer_type = layer_type
        self.structure = structure

        # Compute scope from leaf modules
        self.scope = Scope.join_all([leaf.scope for leaf in leaf_modules])
        self.num_features = len(self.scope.query)

        # Validate depth against number of features
        if 2**depth > self.num_features:
            raise InvalidParameterError(
                f"depth {depth} too large for {self.num_features} features. "
                f"Maximum depth is {int(np.floor(np.log2(self.num_features)))}."
            )

        # Build the architecture
        if structure == "top-down":
            self._build_structure_top_down()
        else:
            self._build_structure_bottom_up()

        # Shape computation
        self.in_shape = self.root_node.in_shape
        self.out_shape = self.root_node.out_shape
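    # Example (illustrative sketch, not part of this module): constructing an Einet.
    # The `Normal` leaf class, its import path, and its constructor arguments are
    # assumptions for illustration only; substitute whatever LeafModule subclass and
    # signature your spflow version provides.
    #
    #   from spflow.modules.leaves import Normal  # hypothetical import path
    #
    #   leaves = [Normal(scope=Scope([i]), out_channels=10) for i in range(8)]
    #   model = Einet(
    #       leaf_modules=leaves,
    #       num_classes=3,        # three root sum nodes, one per class
    #       num_sums=10,
    #       num_leaves=10,
    #       depth=3,              # 2**3 = 8 grouped features, matching the leaves above
    #       num_repetitions=5,
    #   )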
    def _get_layer_class(self):
        """Get the layer class based on layer_type."""
        if self.layer_type == "einsum":
            return EinsumLayer
        else:
            return LinsumLayer

    def _build_structure_top_down(self) -> None:
        """Build Einet structure from top (root) to bottom (leaves).

        In top-down mode, we define layers starting from the root and work
        down to the leaves. Each layer i has 2^i input features.
        """
        LayerClass = self._get_layer_class()

        layers: list[dict[str, int]] = []

        # Build layers from top (i=1) to bottom (i=depth)
        for i in range(1, self.depth + 1):
            # Number of input channels
            if i < self.depth:
                in_channels = self.num_sums
            else:
                in_channels = self.num_leaves

            # Number of output channels
            if i == 1:
                out_channels = self.num_classes
            else:
                out_channels = self.num_sums

            # Number of features at this layer
            in_features = 2**i

            # Record the placeholder shape information for layer construction;
            # the layers are connected properly after all shapes are collected.
            layers.append(
                {
                    "in_features": in_features,
                    "in_channels": in_channels,
                    "out_channels": out_channels,
                }
            )

        # Handle depth=0 case: single sum layer
        if self.depth == 0:
            # Create factorized leaves with single output feature
            fac_layer = Factorize(
                inputs=list(self.leaf_modules),
                depth=0,
                num_repetitions=self.num_repetitions,
            )
            # Single sum layer from leaves to root
            root = Sum(
                inputs=fac_layer,
                out_channels=self.num_classes,
                num_repetitions=self.num_repetitions,
            )
        else:
            # Create factorized leaves
            leaf_num_features_out = 2**self.depth
            fac_layer = Factorize(
                inputs=list(self.leaf_modules),
                depth=self.depth,
                num_repetitions=self.num_repetitions,
            )

            # Build layers bottom-up (reverse of how they process data)
            current = fac_layer
            for i in range(self.depth, 0, -1):
                layer_info = layers[i - 1]

                # Create einsum/linsum layer
                current = LayerClass(
                    inputs=current,
                    out_channels=layer_info["out_channels"],
                    num_repetitions=self.num_repetitions,
                )

            root = current

        # Mix repetitions if we have multiple
        if self.num_repetitions > 1:
            root = RepetitionMixingLayer(
                inputs=root,
                out_channels=self.num_classes,
                num_repetitions=self.num_repetitions,
            )

        # Final root sum if multiple classes
        if self.num_classes > 1 and not isinstance(root, RepetitionMixingLayer):
            self.root_node = Sum(inputs=root, out_channels=1, num_repetitions=1)
        else:
            self.root_node = root

        # Store layers for access
        self.factorize = fac_layer

    def _build_structure_bottom_up(self) -> None:
        """Build Einet structure from bottom (leaves) to top (root).

        In bottom-up mode, we start with the full feature set and progressively
        halve features at each layer.
        """
        LayerClass = self._get_layer_class()

        # Create factorized leaves (no depth reduction, just random permutations)
        fac_layer = Factorize(
            inputs=list(self.leaf_modules),
            depth=int(np.log2(self.num_features)) if self.num_features > 1 else 0,
            num_repetitions=self.num_repetitions,
        )

        # Generate random permutations for each repetition
        permutations = torch.empty((self.num_repetitions, self.num_features), dtype=torch.long)
        for r in range(self.num_repetitions):
            permutations[r] = torch.randperm(self.num_features)
        self.register_buffer("permutation", permutations)

        # Build layers from leaves to root
        current = fac_layer
        in_features = fac_layer.out_shape.features

        for i in range(self.depth):
            # Create einsum/linsum layer
            current = LayerClass(
                inputs=current,
                out_channels=self.num_sums,
                num_repetitions=self.num_repetitions,
            )
            in_features = current.out_shape.features

        # Handle depth=0 case
        if self.depth == 0:
            # Single sum layer
            current = Sum(
                inputs=fac_layer,
                out_channels=self.num_sums,
                num_repetitions=self.num_repetitions,
            )

        # Add sum layer to get features down to 1 if needed
        if current.out_shape.features > 1:
            current = Sum(
                inputs=current,
                out_channels=self.num_sums,
                num_repetitions=self.num_repetitions,
            )

        # Add final sum layer to convert from num_sums to num_classes
        if current.out_shape.channels != self.num_classes:
            current = Sum(
                inputs=current,
                out_channels=self.num_classes,
                num_repetitions=self.num_repetitions,
            )

        # Mix repetitions if we have multiple
        if self.num_repetitions > 1:
            root = RepetitionMixingLayer(
                inputs=current,
                out_channels=self.num_classes,
                num_repetitions=self.num_repetitions,
            )
        else:
            root = current

        self.root_node = root

        # Store layers for access
        self.factorize = fac_layer

    @property
    def n_out(self) -> int:
        """Number of output nodes."""
        return 1

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Mapping from output features to their scopes."""
        return self.root_node.feature_to_scope

    @property
    def scopes_out(self) -> list[Scope]:
        """Output scopes."""
        return self.root_node.scopes_out
    @cached
    def log_likelihood(
        self,
        data: torch.Tensor,
        cache: Cache | None = None,
    ) -> torch.Tensor:
        """Compute log-likelihood for input data.

        Args:
            data: Input data tensor of shape (batch_size, num_features).
            cache: Optional cache for intermediate results.

        Returns:
            Log-likelihood tensor of shape (batch_size, 1, num_classes, 1).
        """
        # Apply permutation if in bottom-up mode
        if hasattr(self, "permutation") and self.structure == "bottom-up":
            # Permute features for each repetition
            # This is handled inside the factorize layer via its indices
            pass

        return self.root_node.log_likelihood(data, cache=cache)
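    # Example (sketch): evaluating log p(x) on a batch; `model` is a placeholder
    # for a constructed single-class Einet, and gradient training via autograd is
    # an assumption about how the surrounding code is used.
    #
    #   x = torch.randn(64, model.num_features)   # (batch_size, num_features)
    #   ll = model.log_likelihood(x)              # (batch_size, 1, num_classes, 1)
    #   nll = -ll.squeeze().mean()                # scalar negative log-likelihood
    #   nll.backward()                            # standard torch gradient step setup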
    def log_posterior(
        self,
        data: torch.Tensor,
        cache: Cache | None = None,
    ) -> torch.Tensor:
        """Compute log-posterior probabilities for multi-class models.

        Args:
            data: Input data tensor.
            cache: Optional cache for intermediate results.

        Returns:
            Log-posterior probabilities of shape (batch_size, num_classes).

        Raises:
            UnsupportedOperationError: If model has only one class.
        """
        if self.num_classes <= 1:
            raise UnsupportedOperationError(
                "Posterior can only be computed for models with multiple classes."
            )

        ll_y = self.root_node.log_weights
        ll_y = ll_y.squeeze().view(1, -1)

        ll = self.root_node.inputs.log_likelihood(data, cache=cache)
        ll = ll.squeeze(-1).squeeze(1)

        return log_posterior(log_likelihood=ll, log_prior=ll_y)
    def predict_proba(self, data: torch.Tensor) -> torch.Tensor:
        """Predict class probabilities.

        Args:
            data: Input data tensor.

        Returns:
            Class probabilities of shape (batch_size, num_classes).
        """
        log_post = self.log_posterior(data)
        return torch.exp(log_post)
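    # Example (sketch): classification with a multi-class Einet. Assumes `model`
    # was built with num_classes > 1 and `x` has shape (batch_size, num_features).
    #
    #   proba = model.predict_proba(x)    # (batch_size, num_classes)
    #   y_hat = proba.argmax(dim=-1)      # predicted class indices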
    def sample(
        self,
        num_samples: int | None = None,
        data: torch.Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        sampling_ctx: SamplingContext | None = None,
    ) -> torch.Tensor:
        """Generate samples from the Einet.

        Args:
            num_samples: Number of samples to generate.
            data: Optional data tensor with NaN values to impute.
            is_mpe: Whether to perform MPE (most probable explanation).
            cache: Optional cache for intermediate results.
            sampling_ctx: Optional sampling context.

        Returns:
            Sampled tensor.

        Raises:
            NotImplementedError: If structure is "bottom-up" (not yet supported).
        """
        # Bottom-up sampling not yet supported due to shape propagation complexity
        if self.structure == "bottom-up":
            raise NotImplementedError(
                "Sampling from bottom-up Einet structure is not yet implemented. "
                "Use structure='top-down' for sampling, or use log_likelihood() which "
                "works for both structures."
            )

        # Handle num_samples case
        if data is None:
            if num_samples is None:
                num_samples = 1
            data = torch.full((num_samples, self.num_features), torch.nan, device=self.device)

        batch_size = data.shape[0]

        # Initialize sampling context
        if sampling_ctx is None:
            sampling_ctx = init_default_sampling_context(None, batch_size, data.device)

        # Always initialize repetition_idx (required by Factorize.sample())
        if sampling_ctx.repetition_idx is None:
            if self.num_repetitions == 1:
                # Single repetition: use index 0 for all samples
                sampling_ctx.repetition_idx = torch.zeros(batch_size, dtype=torch.long, device=data.device)
            # For num_repetitions > 1, RepetitionMixingLayer will set this

        # Handle class sampling for multi-class models
        if self.num_classes > 1:
            logits = self.root_node.logits
            if logits.shape != (1, self.num_classes, 1):
                raise InvalidParameterError(
                    f"Expected logits shape (1, {self.num_classes}, 1), got {logits.shape}"
                )
            logits = logits.squeeze(-1)
            logits = logits.unsqueeze(0).expand(batch_size, -1, -1)

            if is_mpe:
                sampling_ctx.channel_index = torch.argmax(logits, dim=-1)
            else:
                sampling_ctx.channel_index = torch.distributions.Categorical(logits=logits).sample()

        # Sample from appropriate root
        if self.num_classes > 1:
            sample_root = self.root_node.inputs
        else:
            sample_root = self.root_node

        return sample_root.sample(
            data=data,
            is_mpe=is_mpe,
            cache=cache,
            sampling_ctx=sampling_ctx,
        )
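    # Example (sketch): sampling, MPE, and imputation with a top-down Einet
    # (bottom-up sampling raises NotImplementedError, see above). `model` is a
    # placeholder for a constructed Einet.
    #
    #   samples = model.sample(num_samples=16)             # (16, num_features)
    #   mpe = model.sample(num_samples=16, is_mpe=True)    # most probable explanation
    #
    #   # Imputation: per the docstring, NaN entries of `data` are imputed.
    #   x = torch.randn(4, model.num_features)
    #   x[:, 0] = torch.nan
    #   completed = model.sample(data=x)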
    def expectation_maximization(
        self,
        data: torch.Tensor,
        cache: Cache | None = None,
    ) -> None:
        """Perform expectation-maximization step.

        Args:
            data: Input data tensor.
            cache: Optional cache with log-likelihoods.
        """
        self.root_node.expectation_maximization(data, cache=cache)
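    # Example (sketch): a simple EM training loop. Whether a single call performs a
    # full E+M step or requires an externally pre-filled cache of log-likelihoods
    # depends on the Cache/EM contract of this spflow version; treat this as an
    # assumption rather than the definitive usage.
    #
    #   for _ in range(num_em_iterations):
    #       model.expectation_maximization(train_data)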
    def maximum_likelihood_estimation(
        self,
        data: torch.Tensor,
        weights: torch.Tensor | None = None,
        cache: Cache | None = None,
    ) -> None:
        """Update parameters via maximum likelihood estimation.

        Args:
            data: Input data tensor.
            weights: Optional sample weights.
            cache: Optional cache.
        """
        self.root_node.maximum_likelihood_estimation(data, weights=weights, cache=cache)
    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> Module | None:
        """Marginalize out specified random variables.

        Args:
            marg_rvs: Random variable indices to marginalize.
            prune: Whether to prune redundant modules.
            cache: Optional cache.

        Returns:
            Marginalized module or None if fully marginalized.
        """
        return self.root_node.marginalize(marg_rvs, prune=prune, cache=cache)
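    # Example (sketch): marginalizing out the first two random variables of the
    # scope. Per the docstring, the result is a new module, or None if every
    # variable was marginalized; `model` is a placeholder.
    #
    #   marg = model.marginalize([0, 1])
    #   # marg is None only if all variables in the scope were marginalized.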
    def extra_repr(self) -> str:
        """String representation of module configuration."""
        return (
            f"num_features={self.num_features}, num_classes={self.num_classes}, "
            f"num_sums={self.num_sums}, num_leaves={self.num_leaves}, "
            f"depth={self.depth}, num_repetitions={self.num_repetitions}, "
            f"layer_type={self.layer_type}, structure={self.structure}"
        )