Source code for spflow.modules.einsum.linsum_layer

"""LinsumLayer for efficient linear sum-product operations in probabilistic circuits.

Unlike EinsumLayer which computes a cross-product of input channels,
LinsumLayer computes a linear combination: it adds left/right features
(product in log-space), then applies a weighted sum over input channels.
"""

from __future__ import annotations

from typing import Optional

import numpy as np
import torch
from torch import Tensor, nn

from spflow.exceptions import InvalidWeightsError, MissingCacheError, ScopeError, ShapeError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import proj_convex_to_real
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context


[docs] class LinsumLayer(Module): """LinsumLayer combining product and sum operations with linear channel combination. Unlike EinsumLayer which computes cross-product over channels (I × J combinations), LinsumLayer computes a linear combination: pairs left/right features, adds them (product in log-space), then sums over input channels with learned weights. This results in fewer parameters: weight_shape = (D_out, O, R, C) vs EinsumLayer's (D_out, O, R, I, J). Attributes: logits (Parameter): Unnormalized log-weights for gradient optimization. """
[docs] def __init__( self, inputs: Module | list[Module], out_channels: int, num_repetitions: int | None = None, weights: Tensor | None = None, split_mode: SplitMode | None = None, ) -> None: """Initialize LinsumLayer. Args: inputs: Either a single module (features will be split into pairs) or a list of exactly two modules (left and right children). Unlike EinsumLayer, both inputs must have the same number of channels. out_channels: Number of output sum nodes per feature. num_repetitions: Number of repetitions. If None, inferred from inputs. weights: Optional initial weights tensor. If provided, must have shape (out_features, out_channels, num_repetitions, in_channels). split_mode: Optional split configuration for single input mode. Use SplitMode.consecutive() or SplitMode.interleaved(). Defaults to SplitMode.consecutive(num_splits=2) if not specified. Raises: ValueError: If inputs invalid, out_channels < 1, or weight shape mismatch. """ super().__init__() # ========== 1. INPUT VALIDATION ========== if isinstance(inputs, list): if len(inputs) != 2: raise ValueError( f"LinsumLayer requires exactly 2 input modules when given a list, got {len(inputs)}." ) self._two_inputs = True left_input, right_input = inputs # LinsumLayer requires same number of channels (linear combination, not cross-product) if left_input.out_shape.channels != right_input.out_shape.channels: raise ValueError( f"LinsumLayer requires left and right inputs to have same number of channels: " f"{left_input.out_shape.channels} != {right_input.out_shape.channels}" ) if left_input.out_shape.features != right_input.out_shape.features: raise ValueError( f"Left and right inputs must have same number of features: " f"{left_input.out_shape.features} != {right_input.out_shape.features}" ) if left_input.out_shape.repetitions != right_input.out_shape.repetitions: raise ValueError( f"Left and right inputs must have same number of repetitions: " f"{left_input.out_shape.repetitions} != {right_input.out_shape.repetitions}" ) # Validate disjoint scopes if not Scope.all_pairwise_disjoint([left_input.scope, right_input.scope]): raise ScopeError("Left and right input scopes must be disjoint.") self.inputs = nn.ModuleList([left_input, right_input]) in_channels = left_input.out_shape.channels in_features = left_input.out_shape.features if num_repetitions is None: num_repetitions = left_input.out_shape.repetitions self.scope = Scope.join_all([left_input.scope, right_input.scope]) else: # Single input: will split features into left/right halves self._two_inputs = False if inputs.out_shape.features < 2: raise ValueError( f"LinsumLayer requires at least 2 input features for splitting, " f"got {inputs.out_shape.features}." ) if inputs.out_shape.features % 2 != 0: raise ValueError( f"LinsumLayer requires even number of input features for splitting, " f"got {inputs.out_shape.features}." ) # Use Split directly if already a split module, otherwise create from split_mode if isinstance(inputs, Split): self.inputs = inputs elif split_mode is not None: self.inputs = split_mode.create(inputs) else: # Default: consecutive split with 2 parts self.inputs = SplitConsecutive(inputs) in_channels = inputs.out_shape.channels in_features = inputs.out_shape.features // 2 if num_repetitions is None: num_repetitions = inputs.out_shape.repetitions self.scope = inputs.scope # ========== 2. CONFIGURATION VALIDATION ========== if out_channels < 1: raise ValueError(f"out_channels must be >= 1, got {out_channels}.") # ========== 3. SHAPE COMPUTATION ========== self._in_channels = in_channels self.in_shape = ModuleShape(in_features, in_channels, num_repetitions) self.out_shape = ModuleShape(in_features, out_channels, num_repetitions) # ========== 4. WEIGHT INITIALIZATION ========== # Linear sum: weight over input channels only (not cross-product) self.weights_shape = ( self.out_shape.features, # D_out self.out_shape.channels, # O (output channels) self.out_shape.repetitions, # R self._in_channels, # C (input channels - linear, not cross-product) ) if weights is None: # Initialize weights randomly, normalized over input channels weights = torch.rand(self.weights_shape) + 1e-08 weights = weights / weights.sum(dim=-1, keepdim=True) # Validate weights shape if weights.shape != self.weights_shape: raise ValueError(f"Weight shape mismatch: expected {self.weights_shape}, got {weights.shape}") # Register logits parameter self.logits = nn.Parameter(torch.zeros(self.weights_shape)) # Set weights via property (converts to logits) self.weights = weights
@property def feature_to_scope(self) -> np.ndarray: """Mapping from output features to their scopes.""" if self._two_inputs: # Combine scopes from left and right inputs left_scopes = self.inputs[0].feature_to_scope right_scopes = self.inputs[1].feature_to_scope combined = [] for r in range(self.out_shape.repetitions): rep_scopes = [] for f in range(self.out_shape.features): left_s = left_scopes[f, r] right_s = right_scopes[f, r] rep_scopes.append(Scope.join_all([left_s, right_s])) combined.append(np.array(rep_scopes)) return np.stack(combined, axis=1) else: # Single input split into halves - combine adjacent pairs input_scopes = self.inputs.inputs.feature_to_scope combined = [] for r in range(self.out_shape.repetitions): rep_scopes = [] for f in range(self.out_shape.features): left_s = input_scopes[2 * f, r] right_s = input_scopes[2 * f + 1, r] rep_scopes.append(Scope.join_all([left_s, right_s])) combined.append(np.array(rep_scopes)) return np.stack(combined, axis=1) @property def log_weights(self) -> Tensor: """Log-normalized weights (sum to 1 over input channels).""" return torch.nn.functional.log_softmax(self.logits, dim=-1) @property def weights(self) -> Tensor: """Normalized weights (sum to 1 over input channels).""" return torch.nn.functional.softmax(self.logits, dim=-1) @weights.setter def weights(self, values: Tensor) -> None: """Set weights (must be positive and sum to 1 over channels).""" if values.shape != self.weights_shape: raise ShapeError(f"Weight shape mismatch: expected {self.weights_shape}, got {values.shape}") if not torch.all(values > 0): raise InvalidWeightsError("Weights must be positive.") sums = values.sum(dim=-1) if not torch.allclose(sums, torch.ones_like(sums)): raise InvalidWeightsError("Weights must sum to 1 over input channels.") # Project to logits space self.logits.data = proj_convex_to_real(values) @log_weights.setter def log_weights(self, values: Tensor) -> None: """Set log weights directly.""" if values.shape != self.weights_shape: raise ShapeError(f"Log weight shape mismatch: expected {self.weights_shape}, got {values.shape}") self.logits.data = values def extra_repr(self) -> str: return f"{super().extra_repr()}, weights={self.weights_shape}" def _get_left_right_ll(self, data: Tensor, cache: Cache | None = None) -> tuple[Tensor, Tensor]: """Get log-likelihoods from left and right children. Returns: Tuple of (left_ll, right_ll), each of shape (batch, features, channels, reps). """ if self._two_inputs: left_ll = self.inputs[0].log_likelihood(data, cache=cache) right_ll = self.inputs[1].log_likelihood(data, cache=cache) else: # SplitConsecutive returns list of [left, right] lls = self.inputs.log_likelihood(data, cache=cache) left_ll = lls[0] right_ll = lls[1] return left_ll, right_ll
[docs] @cached def log_likelihood( self, data: Tensor, cache: Cache | None = None, ) -> Tensor: """Compute log-likelihood using linear sum over channels. Unlike EinsumLayer which computes cross-product (I × J), this computes a linear combination: add left+right (product), then logsumexp over channels. Args: data: Input data of shape (batch_size, num_features). cache: Optional cache for intermediate results. Returns: Log-likelihood tensor of shape (batch, out_features, out_channels, reps). """ # Get child log-likelihoods left_ll, right_ll = self._get_left_right_ll(data, cache) # Dimensions: N=batch, D=features, C=channels, R=reps N, D, C, R = left_ll.size() # Product: left + right in log-space # Shape: (N, D, C, R) prod_ll = left_ll + right_ll # Expand for output channels dimension # prod_ll: (N, D, C, R) -> (N, D, 1, C, R) prod_ll = prod_ll.unsqueeze(2) # Get log weights: (D, O, R, C) -> (1, D, O, C, R) log_weights = self.log_weights.permute(0, 1, 3, 2).unsqueeze(0) # Weighted sum over input channels # (N, D, 1, C, R) + (1, D, O, C, R) -> (N, D, O, C, R) weighted_ll = prod_ll + log_weights # LogSumExp over input channels (dim=3) log_prob = torch.logsumexp(weighted_ll, dim=3) # (N, D, O, R) return log_prob
[docs] def sample( self, num_samples: int | None = None, data: Tensor | None = None, is_mpe: bool = False, cache: Cache | None = None, sampling_ctx: SamplingContext | None = None, ) -> Tensor: """Sample from the LinsumLayer. Args: num_samples: Number of samples to generate. data: Optional data tensor with evidence (NaN for missing). is_mpe: Whether to perform MPE instead of sampling. cache: Optional cache with log-likelihoods for conditional sampling. sampling_ctx: Sampling context with channel indices. Returns: Sampled data tensor. """ # Prepare data tensor data = self._prepare_sample_data(num_samples, data) if cache is None: cache = Cache() sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device) # Get logits and select based on context logits = self.logits # (D, O, R, C) # Expand for batch dimension batch_size = sampling_ctx.channel_index.shape[0] logits = logits.unsqueeze(0).expand(batch_size, -1, -1, -1, -1) # logits shape: (B, D, O, R, C) # Select output channel based on parent's channel_index channel_idx = sampling_ctx.channel_index # (B, D) # Gather the correct output channel idx = channel_idx.view(batch_size, self.out_shape.features, 1, 1, 1) idx = idx.expand(-1, -1, -1, self.out_shape.repetitions, self._in_channels) logits = logits.gather(dim=2, index=idx).squeeze(2) # logits shape: (B, D, R, C) # Select repetition if specified if sampling_ctx.repetition_idx is not None: rep_idx = sampling_ctx.repetition_idx.view(-1, 1, 1, 1) rep_idx = rep_idx.expand(-1, self.out_shape.features, -1, self._in_channels) logits = logits.gather(dim=2, index=rep_idx).squeeze(2) # logits shape: (B, D, C) else: if self.out_shape.repetitions > 1: raise ValueError("repetition_idx must be provided when sampling with num_repetitions > 1") logits = logits[:, :, 0, :] # (B, D, C) # Condition on evidence if cache has log-likelihoods if self._two_inputs: left_cache_key = self.inputs[0] right_cache_key = self.inputs[1] else: left_cache_key = "linsum_left" right_cache_key = "linsum_right" if ( cache is not None and "log_likelihood" in cache and cache["log_likelihood"].get(left_cache_key) is not None and cache["log_likelihood"].get(right_cache_key) is not None ): # Get cached log-likelihoods left_ll = cache["log_likelihood"][left_cache_key] # (B, D, C, R) right_ll = cache["log_likelihood"][right_cache_key] # (B, D, C, R) # Select repetition if sampling_ctx.repetition_idx is not None: rep_idx = sampling_ctx.repetition_idx.view(-1, 1, 1, 1) rep_idx_l = rep_idx.expand(-1, left_ll.shape[1], left_ll.shape[2], -1) left_ll = left_ll.gather(dim=-1, index=rep_idx_l).squeeze(-1) right_ll = right_ll.gather(dim=-1, index=rep_idx_l).squeeze(-1) # Product log-likelihood for each channel prod_ll = left_ll + right_ll # (B, D, C) # Compute posterior log_prior = logits log_posterior = log_prior + prod_ll log_posterior = log_posterior - torch.logsumexp(log_posterior, dim=-1, keepdim=True) logits = log_posterior # Sample or MPE if is_mpe: indices = logits.argmax(dim=-1) # (B, D) else: dist = torch.distributions.Categorical(logits=logits) indices = dist.sample() # (B, D) # Sample from left and right children with same channel index # (Linear combination means left and right use the same channel) if self._two_inputs: # Left child left_ctx = sampling_ctx.copy() left_ctx.channel_index = indices self.inputs[0].sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=left_ctx) # Right child right_ctx = sampling_ctx.copy() right_ctx.channel_index = indices self.inputs[1].sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=right_ctx) else: # Single input with Split module - use generic merge_split_indices # For LinsumLayer, both left and right use the same indices (linear combination) full_indices = self.inputs.merge_split_indices(indices, indices) full_mask = sampling_ctx.mask.repeat(1, 2) child_ctx = sampling_ctx.copy() child_ctx.update(channel_index=full_indices, mask=full_mask) self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=child_ctx) return data
[docs] def expectation_maximization( self, data: Tensor, bias_correction: bool = True, cache: Cache | None = None, ) -> None: """Perform EM step to update weights. Args: data: Training data tensor. bias_correction: Whether to apply bias correction. cache: Cache with log-likelihoods. """ if cache is None: cache = Cache() with torch.no_grad(): # Get cached values left_ll, right_ll = self._get_left_right_ll(data, cache) module_lls = cache["log_likelihood"].get(self) if module_lls is None: raise MissingCacheError("Module log-likelihoods not in cache. Call log_likelihood first.") # E-step: compute expected counts log_weights = self.log_weights.unsqueeze(0) # (1, D, O, R, C) # Product of left and right prod_ll = left_ll + right_ll # (B, D, C, R) # Rearrange to match weights: (B, D, O, R, C) prod_ll = prod_ll.permute(0, 1, 3, 2).unsqueeze(2) # (B, D, 1, R, C) # Get gradients log_grads = torch.log(module_lls.grad + 1e-10) log_grads = log_grads.unsqueeze(-1) # (B, D, O, R, 1) module_lls = module_lls.unsqueeze(-1) # (B, D, O, R, 1) # Compute log expectations log_expectations = log_weights + log_grads + prod_ll - module_lls log_expectations = log_expectations.logsumexp(0) # Sum over batch # Normalize to get new log weights new_log_weights = torch.nn.functional.log_softmax(log_expectations, dim=-1) # M-step: update weights self.log_weights = new_log_weights # Recurse to children if self._two_inputs: self.inputs[0].expectation_maximization(data, bias_correction=bias_correction, cache=cache) self.inputs[1].expectation_maximization(data, bias_correction=bias_correction, cache=cache) else: self.inputs.inputs.expectation_maximization(data, bias_correction=bias_correction, cache=cache)
[docs] def maximum_likelihood_estimation( self, data: Tensor, weights: Tensor | None = None, bias_correction: bool = True, nan_strategy: str = "ignore", cache: Cache | None = None, ) -> None: """MLE step (equivalent to EM for sum nodes).""" self.expectation_maximization(data, bias_correction=bias_correction, cache=cache)
[docs] def marginalize( self, marg_rvs: list[int], prune: bool = True, cache: Cache | None = None, ) -> Optional["LinsumLayer" | Module]: """Marginalize out specified random variables. Args: marg_rvs: Random variable indices to marginalize. prune: Whether to prune unnecessary modules. cache: Cache for memoization. Returns: Marginalized module or None if fully marginalized. """ if cache is None: cache = Cache() module_scope = self.scope mutual_rvs = set(module_scope.query).intersection(set(marg_rvs)) # Fully marginalized if len(mutual_rvs) == len(module_scope.query): return None # No overlap - return self unchanged if not mutual_rvs: return self # Partially marginalized if self._two_inputs: left_marg = self.inputs[0].marginalize(marg_rvs, prune=prune, cache=cache) right_marg = self.inputs[1].marginalize(marg_rvs, prune=prune, cache=cache) if left_marg is None and right_marg is None: return None elif left_marg is None: return right_marg elif right_marg is None: return left_marg else: # Both still exist - create new LinsumLayer with marginalized children return LinsumLayer( inputs=[left_marg, right_marg], out_channels=self.out_shape.channels, num_repetitions=self.out_shape.repetitions, ) else: # Single input - marginalize the underlying input marg_input = self.inputs.inputs.marginalize(marg_rvs, prune=prune, cache=cache) if marg_input is None: return None # Check if we still have enough features for LinsumLayer if marg_input.out_shape.features < 2: return marg_input if marg_input.out_shape.features % 2 != 0: # Odd number of features - can't use LinsumLayer return marg_input return LinsumLayer( inputs=marg_input, out_channels=self.out_shape.channels, num_repetitions=self.out_shape.repetitions, )