# Source code for spflow.modules.einsum.einsum_layer

"""EinsumLayer for efficient sum-product operations in probabilistic circuits.

Implements the EinsumLayer as described in the Einet paper, combining product
and sum operations into a single efficient einsum operation using the
LogEinsumExp trick for numerical stability.
"""

from __future__ import annotations

from typing import Optional

import numpy as np
import torch
from torch import Tensor, nn

from spflow.exceptions import InvalidWeightsError, MissingCacheError, ScopeError, ShapeError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import proj_convex_to_real
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context


class EinsumLayer(Module):
    """EinsumLayer combining product and sum operations efficiently.

    Implements sum(product(x)) using a single einsum for circuits with
    arbitrary tree structure. Takes pairs of features as left/right children,
    computes their cross-product over channels, and sums the result with
    learned weights. The LogEinsumExp trick (subtract per-channel maxima,
    exponentiate, einsum, re-add maxima in log-space) is used for numerical
    stability.

    Attributes:
        logits (Parameter): Unnormalized log-weights of shape
            (out_features, out_channels, repetitions, left_channels,
            right_channels); normalized lazily via softmax over the
            flattened (left, right) channel pairs.
        unraveled_channel_indices (Tensor): Buffer mapping a flat channel-pair
            index back to its (left_channel, right_channel) pair; used when
            sampling to route indices to the children.
    """
    def __init__(
        self,
        inputs: Module | list[Module],
        out_channels: int,
        num_repetitions: int | None = None,
        weights: Tensor | None = None,
        split_mode: SplitMode | None = None,
    ) -> None:
        """Initialize EinsumLayer.

        Args:
            inputs: Either a single module (features will be split into pairs)
                or a list of exactly two modules (left and right children).
            out_channels: Number of output sum nodes per feature.
            num_repetitions: Number of repetitions. If None, inferred from inputs.
            weights: Optional initial weights tensor. If provided, must have shape
                (out_features, out_channels, num_repetitions, left_channels,
                right_channels) and sum to 1 over the last two dimensions.
            split_mode: Optional split configuration for single input mode.
                Use SplitMode.consecutive() or SplitMode.interleaved().
                Defaults to SplitMode.consecutive(num_splits=2) if not specified.

        Raises:
            ValueError: If inputs invalid, out_channels < 1, or weight shape mismatch.
            ScopeError: If the two explicit inputs have overlapping scopes.
        """
        super().__init__()

        # ========== 1. INPUT VALIDATION ==========
        if isinstance(inputs, list):
            if len(inputs) != 2:
                raise ValueError(
                    f"EinsumLayer requires exactly 2 input modules when given a list, got {len(inputs)}."
                )
            self._two_inputs = True
            left_input, right_input = inputs
            # Validate compatible shapes (channels can differ for cross-product)
            if left_input.out_shape.features != right_input.out_shape.features:
                raise ValueError(
                    f"Left and right inputs must have same number of features: "
                    f"{left_input.out_shape.features} != {right_input.out_shape.features}"
                )
            if left_input.out_shape.repetitions != right_input.out_shape.repetitions:
                raise ValueError(
                    f"Left and right inputs must have same number of repetitions: "
                    f"{left_input.out_shape.repetitions} != {right_input.out_shape.repetitions}"
                )
            # Validate disjoint scopes: a product node must combine
            # non-overlapping variable sets.
            if not Scope.all_pairwise_disjoint([left_input.scope, right_input.scope]):
                raise ScopeError("Left and right input scopes must be disjoint.")
            self.inputs = nn.ModuleList([left_input, right_input])
            self._left_channels = left_input.out_shape.channels
            self._right_channels = right_input.out_shape.channels
            in_features = left_input.out_shape.features
            if num_repetitions is None:
                num_repetitions = left_input.out_shape.repetitions
            self.scope = Scope.join_all([left_input.scope, right_input.scope])
        else:
            # Single input: will split features into left/right halves.
            self._two_inputs = False
            if inputs.out_shape.features < 2:
                raise ValueError(
                    f"EinsumLayer requires at least 2 input features for splitting, "
                    f"got {inputs.out_shape.features}."
                )
            if inputs.out_shape.features % 2 != 0:
                raise ValueError(
                    f"EinsumLayer requires even number of input features for splitting, "
                    f"got {inputs.out_shape.features}."
                )
            # Use Split directly if already a split module, otherwise create from split_mode.
            if isinstance(inputs, Split):
                self.inputs = inputs
            elif split_mode is not None:
                self.inputs = split_mode.create(inputs)
            else:
                # Default: consecutive split with 2 parts.
                self.inputs = SplitConsecutive(inputs)
            # Both halves come from the same module, so channel counts match.
            self._left_channels = inputs.out_shape.channels
            self._right_channels = inputs.out_shape.channels
            in_features = inputs.out_shape.features // 2
            if num_repetitions is None:
                num_repetitions = inputs.out_shape.repetitions
            self.scope = inputs.scope

        # ========== 2. CONFIGURATION VALIDATION ==========
        if out_channels < 1:
            raise ValueError(f"out_channels must be >= 1, got {out_channels}.")

        # ========== 3. SHAPE COMPUTATION ==========
        # Use max channels for in_shape (for informational purposes).
        max_in_channels = max(self._left_channels, self._right_channels)
        self.in_shape = ModuleShape(in_features, max_in_channels, num_repetitions)
        self.out_shape = ModuleShape(in_features, out_channels, num_repetitions)

        # ========== 4. WEIGHT INITIALIZATION ==========
        self.weights_shape = (
            self.out_shape.features,  # D_out
            self.out_shape.channels,  # O (output channels)
            self.out_shape.repetitions,  # R
            self._left_channels,  # I (left input channels)
            self._right_channels,  # J (right input channels)
        )
        # Create index mapping for sampling: flatten (i,j) -> idx and back.
        self.register_buffer(
            "unraveled_channel_indices",
            torch.tensor(
                [(i, j) for i in range(self._left_channels) for j in range(self._right_channels)],
                dtype=torch.long,
            ),
        )
        if weights is None:
            # Initialize weights randomly, normalized over (i,j) pairs;
            # the small epsilon keeps every weight strictly positive.
            weights = torch.rand(self.weights_shape) + 1e-08
            weights = weights / weights.sum(dim=(-2, -1), keepdim=True)
        # Validate weights shape (covers user-supplied weights).
        if weights.shape != self.weights_shape:
            raise ValueError(f"Weight shape mismatch: expected {self.weights_shape}, got {weights.shape}")
        # Register logits parameter first, then set via the property so the
        # weights are projected into unconstrained logit space.
        self.logits = nn.Parameter(torch.zeros(self.weights_shape))
        self.weights = weights
    @property
    def feature_to_scope(self) -> np.ndarray:
        """Mapping from output features to their scopes.

        Returns:
            Array of Scope objects indexed as [feature, repetition].
        """
        if self._two_inputs:
            # Combine scopes from left and right inputs: output feature f
            # merges feature f of each child.
            left_scopes = self.inputs[0].feature_to_scope
            right_scopes = self.inputs[1].feature_to_scope
            combined = []
            for r in range(self.out_shape.repetitions):
                rep_scopes = []
                for f in range(self.out_shape.features):
                    left_s = left_scopes[f, r]
                    right_s = right_scopes[f, r]
                    rep_scopes.append(Scope.join_all([left_s, right_s]))
                combined.append(np.array(rep_scopes))
            return np.stack(combined, axis=1)
        else:
            # Single input split into halves - combine adjacent pairs.
            # NOTE(review): this pairs raw input features (2f, 2f+1); confirm
            # this matches the feature ordering produced by the configured
            # split module (SplitConsecutive splits elsewhere may pair
            # (f, f + D/2) instead) — TODO verify against Split implementation.
            input_scopes = self.inputs.inputs.feature_to_scope
            combined = []
            for r in range(self.out_shape.repetitions):
                rep_scopes = []
                for f in range(self.out_shape.features):
                    left_s = input_scopes[2 * f, r]
                    right_s = input_scopes[2 * f + 1, r]
                    rep_scopes.append(Scope.join_all([left_s, right_s]))
                combined.append(np.array(rep_scopes))
            return np.stack(combined, axis=1)

    @property
    def log_weights(self) -> Tensor:
        """Log-normalized weights (sum to 1 over input channel pairs).

        The (I, J) pair dimensions are flattened so the softmax normalizes
        jointly over all channel pairs, then reshaped back.
        """
        flat_logits = self.logits.view(*self.logits.shape[:-2], -1)
        log_weights = torch.nn.functional.log_softmax(flat_logits, dim=-1)
        return log_weights.view(self.weights_shape)

    @property
    def weights(self) -> Tensor:
        """Normalized weights (sum to 1 over input channel pairs)."""
        flat_logits = self.logits.view(*self.logits.shape[:-2], -1)
        weights = torch.nn.functional.softmax(flat_logits, dim=-1)
        return weights.view(self.weights_shape)

    @weights.setter
    def weights(self, values: Tensor) -> None:
        """Set weights (must be positive and sum to 1 over i,j pairs).

        Raises:
            ShapeError: If `values` does not match `weights_shape`.
            InvalidWeightsError: If any weight is non-positive or the weights
                do not sum to 1 over the (i, j) channel-pair dimensions.
        """
        if values.shape != self.weights_shape:
            raise ShapeError(f"Weight shape mismatch: expected {self.weights_shape}, got {values.shape}")
        if not torch.all(values > 0):
            raise InvalidWeightsError("Weights must be positive.")
        sums = values.sum(dim=(-2, -1))
        if not torch.allclose(sums, torch.ones_like(sums)):
            raise InvalidWeightsError("Weights must sum to 1 over (i,j) channel pairs.")
        # Project to logits space (inverse of the softmax normalization).
        flat_weights = values.view(*values.shape[:-2], -1)
        flat_logits = proj_convex_to_real(flat_weights)
        self.logits.data = flat_logits.view(self.weights_shape)

    @log_weights.setter
    def log_weights(self, values: Tensor) -> None:
        """Set log weights directly (stored as logits without re-projection).

        Raises:
            ShapeError: If `values` does not match `weights_shape`.
        """
        if values.shape != self.weights_shape:
            raise ShapeError(f"Log weight shape mismatch: expected {self.weights_shape}, got {values.shape}")
        self.logits.data = values

    def extra_repr(self) -> str:
        """Append the weight tensor shape to the module repr."""
        return f"{super().extra_repr()}, weights={self.weights_shape}"

    def _get_left_right_ll(self, data: Tensor, cache: Cache | None = None) -> tuple[Tensor, Tensor]:
        """Get log-likelihoods from left and right children.

        Args:
            data: Input data tensor passed through to the children.
            cache: Optional cache forwarded to child log_likelihood calls.

        Returns:
            Tuple of (left_ll, right_ll), each of shape (batch, features, channels, reps).
        """
        if self._two_inputs:
            left_ll = self.inputs[0].log_likelihood(data, cache=cache)
            right_ll = self.inputs[1].log_likelihood(data, cache=cache)
        else:
            # The Split module returns a list of [left, right] tensors.
            lls = self.inputs.log_likelihood(data, cache=cache)
            left_ll = lls[0]
            right_ll = lls[1]
        return left_ll, right_ll
[docs] @cached def log_likelihood( self, data: Tensor, cache: Cache | None = None, ) -> Tensor: """Compute log-likelihood using LogEinsumExp trick. Args: data: Input data of shape (batch_size, num_features). cache: Optional cache for intermediate results. Returns: Log-likelihood tensor of shape (batch, out_features, out_channels, reps). """ # Get child log-likelihoods left_ll, right_ll = self._get_left_right_ll(data, cache) # Dimensions: N=batch, D=features, C=channels, R=reps N, D, C, R = left_ll.size() # LogEinsumExp trick for numerical stability # Compute max for normalization left_max = torch.max(left_ll, dim=2, keepdim=True)[0] # (N, D, 1, R) left_prob = torch.exp(left_ll - left_max) # (N, D, C, R) right_max = torch.max(right_ll, dim=2, keepdim=True)[0] # (N, D, 1, R) right_prob = torch.exp(right_ll - right_max) # (N, D, C, R) # Get normalized weights weights = self.weights # (D, O, R, I, J) # Einsum: product over channels, weighted sum # n=batch, d=features, i=left_channels, j=right_channels, o=out_channels, r=reps prob = torch.einsum("ndir,ndjr,dorij->ndor", left_prob, right_prob, weights) # Re-add the log maxes log_prob = torch.log(prob) + left_max + right_max return log_prob
[docs] def sample( self, num_samples: int | None = None, data: Tensor | None = None, is_mpe: bool = False, cache: Cache | None = None, sampling_ctx: SamplingContext | None = None, ) -> Tensor: """Sample from the EinsumLayer. Args: num_samples: Number of samples to generate. data: Optional data tensor with evidence (NaN for missing). is_mpe: Whether to perform MPE instead of sampling. cache: Optional cache with log-likelihoods for conditional sampling. sampling_ctx: Sampling context with channel indices. Returns: Sampled data tensor. """ # Prepare data tensor data = self._prepare_sample_data(num_samples, data) if cache is None: cache = Cache() sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device) # Get logits and select based on context logits = self.logits # (D, O, R, I, J) # Expand for batch dimension batch_size = sampling_ctx.channel_index.shape[0] logits = logits.unsqueeze(0).expand(batch_size, -1, -1, -1, -1, -1) # logits shape: (B, D, O, R, I, J) # Select output channel based on parent's channel_index # sampling_ctx.channel_index: (B, D) - indices into out_channels channel_idx = sampling_ctx.channel_index # (B, D) # Gather the correct output channel # Expand channel_idx to match logits dimensions idx = channel_idx.view(batch_size, self.out_shape.features, 1, 1, 1, 1) idx = idx.expand(-1, -1, -1, self.out_shape.repetitions, self._left_channels, self._right_channels) logits = logits.gather(dim=2, index=idx).squeeze(2) # logits shape: (B, D, R, I, J) # Select repetition if specified if sampling_ctx.repetition_idx is not None: rep_idx = sampling_ctx.repetition_idx.view(-1, 1, 1, 1, 1) rep_idx = rep_idx.expand( -1, self.out_shape.features, -1, self._left_channels, self._right_channels ) logits = logits.gather(dim=2, index=rep_idx).squeeze(2) # logits shape: (B, D, I, J) else: if self.out_shape.repetitions > 1: raise ValueError("repetition_idx must be provided when sampling with num_repetitions > 1") logits = logits[:, :, 0, :, :] # 
(B, D, I, J) # Flatten (I, J) for categorical sampling logits_flat = logits.view(batch_size, self.out_shape.features, -1) # (B, D, I*J) # Condition on evidence if cache has log-likelihoods if self._two_inputs: left_cache_key = self.inputs[0] right_cache_key = self.inputs[1] else: left_cache_key = "einsum_left" right_cache_key = "einsum_right" if ( cache is not None and "log_likelihood" in cache and cache["log_likelihood"].get(left_cache_key) is not None and cache["log_likelihood"].get(right_cache_key) is not None ): # Get cached log-likelihoods left_ll = cache["log_likelihood"][left_cache_key] # (B, D, C, R) right_ll = cache["log_likelihood"][right_cache_key] # (B, D, C, R) # Select repetition if sampling_ctx.repetition_idx is not None: rep_idx = sampling_ctx.repetition_idx.view(-1, 1, 1, 1) rep_idx_l = rep_idx.expand(-1, left_ll.shape[1], left_ll.shape[2], -1) left_ll = left_ll.gather(dim=-1, index=rep_idx_l).squeeze(-1) right_ll = right_ll.gather(dim=-1, index=rep_idx_l).squeeze(-1) # Compute joint log-likelihood for each (i, j) pair # left_ll: (B, D, I), right_ll: (B, D, J) left_ll = left_ll.unsqueeze(-1) # (B, D, I, 1) right_ll = right_ll.unsqueeze(-2) # (B, D, 1, J) joint_ll = left_ll + right_ll # (B, D, I, J) joint_ll_flat = joint_ll.view(batch_size, self.out_shape.features, -1) # Compute posterior log_prior = logits_flat log_posterior = log_prior + joint_ll_flat log_posterior = log_posterior - torch.logsumexp(log_posterior, dim=-1, keepdim=True) logits_flat = log_posterior # Sample or MPE if is_mpe: indices = logits_flat.argmax(dim=-1) # (B, D) else: dist = torch.distributions.Categorical(logits=logits_flat) indices = dist.sample() # (B, D) # Unravel indices to (i, j) pairs ij_indices = self.unraveled_channel_indices[indices] # (B, D, 2) left_indices = ij_indices[..., 0] # (B, D) right_indices = ij_indices[..., 1] # (B, D) # Sample from left and right children if self._two_inputs: # Left child left_ctx = sampling_ctx.copy() left_ctx.channel_index = 
left_indices self.inputs[0].sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=left_ctx) # Right child right_ctx = sampling_ctx.copy() right_ctx.channel_index = right_indices self.inputs[1].sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=right_ctx) else: # Single input with Split module - use generic merge_split_indices full_indices = self.inputs.merge_split_indices(left_indices, right_indices) full_mask = sampling_ctx.mask.repeat(1, 2) child_ctx = sampling_ctx.copy() child_ctx.update(channel_index=full_indices, mask=full_mask) self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=child_ctx) return data
    def expectation_maximization(
        self,
        data: Tensor,
        bias_correction: bool = True,
        cache: Cache | None = None,
    ) -> None:
        """Perform EM step to update weights.

        Args:
            data: Training data tensor.
            bias_correction: Whether to apply bias correction (forwarded to children).
            cache: Cache with log-likelihoods.

        Raises:
            MissingCacheError: If this module's log-likelihoods are not cached.
        """
        if cache is None:
            cache = Cache()
        with torch.no_grad():
            # Get cached child values (recomputed from cache if available).
            left_ll, right_ll = self._get_left_right_ll(data, cache)
            module_lls = cache["log_likelihood"].get(self)
            if module_lls is None:
                raise MissingCacheError("Module log-likelihoods not in cache. Call log_likelihood first.")

            # E-step: compute expected counts.
            log_weights = self.log_weights.unsqueeze(0)  # (1, D, O, R, I, J)

            # Expand input lls for the (I, J) outer product.
            left_ll = left_ll.unsqueeze(3).unsqueeze(5)  # (B, D, I, 1, R, 1)
            right_ll = right_ll.unsqueeze(2).unsqueeze(5)  # (B, D, 1, J, R, 1)
            # Rearrange dimensions to match weights: (D, O, R, I, J).
            left_ll = left_ll.permute(0, 1, 4, 2, 3, 5).squeeze(-1)  # (B, D, R, I, 1)
            right_ll = right_ll.permute(0, 1, 4, 2, 3, 5).squeeze(-1)  # (B, D, R, 1, J)

            # Gradients measure how much each output contributed to the root.
            # NOTE(review): this reads module_lls.grad, which assumes a
            # backward pass populated gradients on the cached tensor (e.g.
            # via retain_grad) before EM is called — confirm the training
            # loop guarantees this.
            log_grads = torch.log(module_lls.grad + 1e-10)
            log_grads = log_grads.unsqueeze(-1).unsqueeze(-1)  # (B, D, O, R, 1, 1)
            module_lls = module_lls.unsqueeze(-1).unsqueeze(-1)  # (B, D, O, R, 1, 1)

            # Joint input log-likelihood over (I, J) pairs (broadcast outer sum).
            joint_input_ll = left_ll.unsqueeze(2) + right_ll.unsqueeze(2)  # (B, D, 1, R, I, J)

            # Log expected counts: prior * responsibility * child likelihood,
            # divided by this module's output likelihood.
            log_expectations = log_weights + log_grads + joint_input_ll - module_lls
            log_expectations = log_expectations.logsumexp(0)  # Sum over batch

            # Normalize over flattened (I, J) pairs to get new log weights.
            flat_expectations = log_expectations.view(*log_expectations.shape[:-2], -1)
            flat_log_weights = torch.nn.functional.log_softmax(flat_expectations, dim=-1)
            new_log_weights = flat_log_weights.view(self.weights_shape)

            # M-step: update weights.
            self.log_weights = new_log_weights

        # Recurse to children.
        if self._two_inputs:
            self.inputs[0].expectation_maximization(data, bias_correction=bias_correction, cache=cache)
            self.inputs[1].expectation_maximization(data, bias_correction=bias_correction, cache=cache)
        else:
            # Skip the Split wrapper and recurse into its underlying module.
            self.inputs.inputs.expectation_maximization(data, bias_correction=bias_correction, cache=cache)
    def maximum_likelihood_estimation(
        self,
        data: Tensor,
        weights: Tensor | None = None,
        bias_correction: bool = True,
        nan_strategy: str = "ignore",
        cache: Cache | None = None,
    ) -> None:
        """MLE step (equivalent to EM for sum nodes).

        Delegates directly to :meth:`expectation_maximization`.

        Note:
            The `weights` and `nan_strategy` arguments are accepted for
            interface compatibility but are not used by this implementation.

        Args:
            data: Training data tensor.
            weights: Unused; present for interface compatibility.
            bias_correction: Whether to apply bias correction.
            nan_strategy: Unused; present for interface compatibility.
            cache: Cache with log-likelihoods.
        """
        self.expectation_maximization(data, bias_correction=bias_correction, cache=cache)
    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> "EinsumLayer | Module | None":
        """Marginalize out specified random variables.

        Args:
            marg_rvs: Random variable indices to marginalize.
            prune: Whether to prune unnecessary modules.
            cache: Cache for memoization.

        Returns:
            Marginalized module or None if fully marginalized. May return a
            child module directly when the EinsumLayer structure can no
            longer be maintained (odd/too-few features, one side removed).
        """
        if cache is None:
            cache = Cache()
        module_scope = self.scope
        mutual_rvs = set(module_scope.query).intersection(set(marg_rvs))
        # Fully marginalized: nothing of this layer's scope remains.
        if len(mutual_rvs) == len(module_scope.query):
            return None
        # No overlap - return self unchanged (inputs unchanged).
        if not mutual_rvs:
            return self
        # Partially marginalized.
        if self._two_inputs:
            left_marg = self.inputs[0].marginalize(marg_rvs, prune=prune, cache=cache)
            right_marg = self.inputs[1].marginalize(marg_rvs, prune=prune, cache=cache)
            if left_marg is None and right_marg is None:
                return None
            elif left_marg is None:
                return right_marg
            elif right_marg is None:
                return left_marg
            else:
                # Both still exist - create new EinsumLayer with marginalized
                # children.
                # NOTE(review): the new layer is built without passing this
                # layer's weights, so they are re-initialized randomly —
                # confirm discarding learned weights is intended here.
                return EinsumLayer(
                    inputs=[left_marg, right_marg],
                    out_channels=self.out_shape.channels,
                    num_repetitions=self.out_shape.repetitions,
                )
        else:
            # Single input - marginalize the underlying input (skip the Split wrapper).
            marg_input = self.inputs.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
            if marg_input is None:
                return None
            # Check if we still have enough features for EinsumLayer.
            if marg_input.out_shape.features < 2:
                return marg_input
            if marg_input.out_shape.features % 2 != 0:
                # Odd number of features - can't use EinsumLayer.
                return marg_input
            return EinsumLayer(
                inputs=marg_input,
                out_channels=self.out_shape.channels,
                num_repetitions=self.out_shape.repetitions,
            )