# Source code for spflow.modules.einsum.linsum_layer

"""LinsumLayer for efficient linear sum-product operations in probabilistic circuits.

Unlike EinsumLayer which computes a cross-product of input channels,
LinsumLayer computes a linear combination: it adds left/right features
(product in log-space), then applies a weighted sum over input channels.
"""

from __future__ import annotations

from typing import Optional

import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor, nn

from spflow.exceptions import (
    InvalidWeightsError,
    MissingCacheError,
    ScopeError,
    ShapeError,
)
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import proj_convex_to_real
from spflow.utils.sampling_context import (
    SamplingContext,
    index_tensor,
    repeat_channel_index,
    repeat_repetition_index,
    sample_from_logits,
)


class LinsumLayer(Module):
    """LinsumLayer combining product and sum operations with linear channel combination.

    Unlike EinsumLayer which computes cross-product over channels (I × J combinations),
    LinsumLayer computes a linear combination: pairs left/right features, adds them
    (product in log-space), then sums over input channels with learned weights.

    This results in fewer parameters: weight_shape = (D_out, O, R, C)
    vs EinsumLayer's (D_out, O, R, I, J).

    Attributes:
        logits (Parameter): Unnormalized log-weights for gradient optimization.
    """
    def __init__(
        self,
        inputs: Module | list[Module],
        out_channels: int,
        num_repetitions: int | None = None,
        weights: Tensor | None = None,
        split_mode: SplitMode | None = None,
    ) -> None:
        """Initialize LinsumLayer.

        Args:
            inputs: Either a single module (features will be split into pairs)
                or a list of exactly two modules (left and right children).
                Unlike EinsumLayer, both inputs must have the same number of
                channels.
            out_channels: Number of output sum nodes per feature.
            num_repetitions: Number of repetitions. If None, inferred from inputs.
            weights: Optional initial weights tensor. If provided, must have
                shape (out_features, out_channels, num_repetitions, in_channels).
            split_mode: Optional split configuration for single input mode.
                Use SplitMode.consecutive() or SplitMode.interleaved().
                Defaults to SplitMode.consecutive(num_splits=2) if not specified.

        Raises:
            ValueError: If inputs invalid, out_channels < 1, or weight shape mismatch.
            ScopeError: If left and right input scopes are not disjoint
                (two-input mode only).
        """
        super().__init__()

        # ========== 1. INPUT VALIDATION ==========
        if isinstance(inputs, list):
            if len(inputs) != 2:
                raise ValueError(
                    f"LinsumLayer requires exactly 2 input modules when given a list, got {len(inputs)}."
                )
            self._two_inputs = True
            left_input, right_input = inputs

            # LinsumLayer requires same number of channels (linear combination, not cross-product)
            if left_input.out_shape.channels != right_input.out_shape.channels:
                raise ValueError(
                    f"LinsumLayer requires left and right inputs to have same number of channels: "
                    f"{left_input.out_shape.channels} != {right_input.out_shape.channels}"
                )
            if left_input.out_shape.features != right_input.out_shape.features:
                raise ValueError(
                    f"Left and right inputs must have same number of features: "
                    f"{left_input.out_shape.features} != {right_input.out_shape.features}"
                )
            if left_input.out_shape.repetitions != right_input.out_shape.repetitions:
                raise ValueError(
                    f"Left and right inputs must have same number of repetitions: "
                    f"{left_input.out_shape.repetitions} != {right_input.out_shape.repetitions}"
                )

            # Validate disjoint scopes (a product node may not repeat a variable)
            if not Scope.all_pairwise_disjoint([left_input.scope, right_input.scope]):
                raise ScopeError("Left and right input scopes must be disjoint.")

            self.inputs = nn.ModuleList([left_input, right_input])
            in_channels = left_input.out_shape.channels
            # Output feature f pairs left feature f with right feature f,
            # so feature count is carried over from the children.
            in_features = left_input.out_shape.features
            if num_repetitions is None:
                num_repetitions = left_input.out_shape.repetitions
            self.scope = Scope.join_all([left_input.scope, right_input.scope])
        else:
            # Single input: will split features into left/right halves
            self._two_inputs = False
            if inputs.out_shape.features < 2:
                raise ValueError(
                    f"LinsumLayer requires at least 2 input features for splitting, "
                    f"got {inputs.out_shape.features}."
                )
            if inputs.out_shape.features % 2 != 0:
                raise ValueError(
                    f"LinsumLayer requires even number of input features for splitting, "
                    f"got {inputs.out_shape.features}."
                )

            # Use Split directly if already a split module, otherwise create from split_mode
            if isinstance(inputs, Split):
                self.inputs = inputs
            elif split_mode is not None:
                self.inputs = split_mode.create(inputs)
            else:
                # Default: consecutive split with 2 parts
                self.inputs = SplitConsecutive(inputs)

            in_channels = inputs.out_shape.channels
            # Each output feature consumes a pair of input features.
            in_features = inputs.out_shape.features // 2
            if num_repetitions is None:
                num_repetitions = inputs.out_shape.repetitions
            self.scope = inputs.scope

        # ========== 2. CONFIGURATION VALIDATION ==========
        if out_channels < 1:
            raise ValueError(f"out_channels must be >= 1, got {out_channels}.")

        # ========== 3. SHAPE COMPUTATION ==========
        self._in_channels = in_channels
        self.in_shape = ModuleShape(in_features, in_channels, num_repetitions)
        self.out_shape = ModuleShape(in_features, out_channels, num_repetitions)

        # ========== 4. WEIGHT INITIALIZATION ==========
        # Linear sum: weight over input channels only (not cross-product)
        self.weights_shape = (
            self.out_shape.features,  # D_out
            self.out_shape.channels,  # O (output channels)
            self.out_shape.repetitions,  # R
            self._in_channels,  # C (input channels - linear, not cross-product)
        )

        if weights is None:
            # Initialize weights randomly, normalized over input channels;
            # the epsilon keeps every weight strictly positive for the setter's check.
            weights = torch.rand(self.weights_shape) + 1e-08
            weights = weights / weights.sum(dim=-1, keepdim=True)

        # Validate weights shape (applies to user-provided weights too)
        if weights.shape != self.weights_shape:
            raise ValueError(f"Weight shape mismatch: expected {self.weights_shape}, got {weights.shape}")

        # Register logits parameter
        self.logits = nn.Parameter(torch.zeros(self.weights_shape))

        # Set weights via property (converts to logits)
        self.weights = weights
    @property
    def feature_to_scope(self) -> np.ndarray:
        """Mapping from output features to their scopes.

        Returns an object array of Scope instances indexed as
        [feature, repetition]; each output scope is the join of the two
        child scopes that feed the corresponding product.
        """
        if self._two_inputs:
            # Combine scopes from left and right inputs (same feature index on both sides)
            left_scopes = self.inputs[0].feature_to_scope
            right_scopes = self.inputs[1].feature_to_scope
            combined = []
            for r in range(self.out_shape.repetitions):
                rep_scopes = []
                for f in range(self.out_shape.features):
                    left_s = left_scopes[f, r]
                    right_s = right_scopes[f, r]
                    rep_scopes.append(Scope.join_all([left_s, right_s]))
                combined.append(np.array(rep_scopes))
            # Stack per-repetition arrays along axis 1 -> (features, repetitions)
            return np.stack(combined, axis=1)
        else:
            # Single input split into halves - combine adjacent pairs (2f, 2f+1)
            input_scopes = self.inputs.inputs.feature_to_scope
            combined = []
            for r in range(self.out_shape.repetitions):
                rep_scopes = []
                for f in range(self.out_shape.features):
                    left_s = input_scopes[2 * f, r]
                    right_s = input_scopes[2 * f + 1, r]
                    rep_scopes.append(Scope.join_all([left_s, right_s]))
                combined.append(np.array(rep_scopes))
            return np.stack(combined, axis=1)

    @property
    def log_weights(self) -> Tensor:
        """Log-normalized weights (log-softmax over the input-channel axis)."""
        return torch.nn.functional.log_softmax(self.logits, dim=-1)

    @property
    def weights(self) -> Tensor:
        """Normalized weights (sum to 1 over input channels)."""
        return torch.nn.functional.softmax(self.logits, dim=-1)

    @weights.setter
    def weights(self, values: Tensor) -> None:
        """Set weights (must be positive and sum to 1 over channels).

        Raises:
            ShapeError: If values do not match ``self.weights_shape``.
            InvalidWeightsError: If any weight is non-positive or channel
                sums deviate from 1.
        """
        if values.shape != self.weights_shape:
            raise ShapeError(f"Weight shape mismatch: expected {self.weights_shape}, got {values.shape}")
        if not torch.all(values > 0):
            raise InvalidWeightsError("Weights must be positive.")
        sums = values.sum(dim=-1)
        if not torch.allclose(sums, torch.ones_like(sums)):
            raise InvalidWeightsError("Weights must sum to 1 over input channels.")
        # Project the convex weights to unconstrained logits space
        self.logits.data = proj_convex_to_real(values)

    @log_weights.setter
    def log_weights(self, values: Tensor) -> None:
        """Set log weights directly (stored as logits; softmax re-normalizes)."""
        if values.shape != self.weights_shape:
            raise ShapeError(f"Log weight shape mismatch: expected {self.weights_shape}, got {values.shape}")
        self.logits.data = values

    def extra_repr(self) -> str:
        # Append the weight tensor shape to the inherited module repr.
        return f"{super().extra_repr()}, weights={self.weights_shape}"

    def _get_left_right_ll(self, data: Tensor, cache: Cache | None = None) -> tuple[Tensor, Tensor]:
        """Get log-likelihoods from left and right children.

        Args:
            data: Input data of shape (batch_size, num_features).
            cache: Optional cache for intermediate results.

        Returns:
            Tuple of (left_ll, right_ll), each of shape (batch, features, channels, reps).
        """
        if self._two_inputs:
            left_ll = self.inputs[0].log_likelihood(data, cache=cache)
            right_ll = self.inputs[1].log_likelihood(data, cache=cache)
        else:
            # Split module returns a list of [left, right] log-likelihood tensors
            lls = self.inputs.log_likelihood(data, cache=cache)
            left_ll = lls[0]
            right_ll = lls[1]
        return left_ll, right_ll
[docs] @cached def log_likelihood( self, data: Tensor, cache: Cache | None = None, ) -> Tensor: """Compute log-likelihood using linear sum over channels. Unlike EinsumLayer which computes cross-product (I × J), this computes a linear combination: add left+right (product), then logsumexp over channels. Args: data: Input data of shape (batch_size, num_features). cache: Optional cache for intermediate results. Returns: Log-likelihood tensor of shape (batch, out_features, out_channels, reps). """ # Get child log-likelihoods left_ll, right_ll = self._get_left_right_ll(data, cache) # Dimensions: N=batch, D=features, C=channels, R=reps N, D, C, R = left_ll.size() # Product: left + right in log-space # Shape: (N, D, C, R) prod_ll = left_ll + right_ll # Expand for output channels dimension # prod_ll: (N, D, C, R) -> (N, D, 1, C, R) prod_ll = rearrange(prod_ll, "n f ci r -> n f 1 ci r") # Get log weights: (D, O, R, C) -> (1, D, O, C, R) log_weights = rearrange(self.log_weights, "f co r ci -> 1 f co ci r") # Weighted sum over input channels # (N, D, 1, C, R) + (1, D, O, C, R) -> (N, D, O, C, R) weighted_ll = prod_ll + log_weights # LogSumExp over input channels (dim=3) log_prob = torch.logsumexp(weighted_ll, dim=3) # (N, D, O, R) return log_prob
    def _sample(
        self,
        data: Tensor,
        sampling_ctx: SamplingContext,
        cache: Cache,
    ) -> Tensor:
        """Sample from the LinsumLayer.

        Selects one input channel per (sample, feature) from the layer's sum
        weights — posterior-weighted when child log-likelihoods are cached —
        and then delegates sampling to the children with that channel index.

        Args:
            data: Data tensor to be filled with samples; may carry evidence.
            sampling_ctx: Sampling context with channel/repetition indices,
                feature mask, and MPE/differentiability settings.
            cache: Cache; if it contains child log-likelihoods, sampling is
                conditioned on them.

        Returns:
            The data tensor updated by the children's ``_sample`` calls.
        """
        # Validate and broadcast the incoming routing information.
        sampling_ctx.validate_sampling_context(
            num_samples=data.shape[0],
            num_features=self.out_shape.features,
            num_channels=self.out_shape.channels,
            num_repetitions=self.out_shape.repetitions,
            allowed_feature_widths=(1, self.out_shape.features),
        )
        sampling_ctx.broadcast_feature_width(target_features=self.out_shape.features, allow_from_one=True)

        # Get logits and select based on context
        logits = self.logits  # (D, O, R, C)

        # Expand for batch dimension
        batch_size = int(sampling_ctx.channel_index.shape[0])
        logits = repeat(logits, "f co r ci -> b f co r ci", b=batch_size)
        # logits shape: (B, D, O, R, C)

        # Select output channel based on parent's channel_index
        channel_idx = sampling_ctx.channel_index

        # Gather the correct output channel
        num_repetitions = self.out_shape.repetitions
        num_input_channels = self._in_channels
        # NOTE(review): repeat_channel_index/index_tensor are project helpers;
        # shapes below are assumed from the pattern strings — confirm in utils.
        idx = repeat_channel_index(
            channel_idx,
            "b f co -> b f co r ci",
            r=num_repetitions,
            ci=num_input_channels,
        )
        logits = index_tensor(
            logits,
            index=idx,
            dim=2,
            is_differentiable=sampling_ctx.is_differentiable,
        )
        # logits shape: (B, D, R, C)

        # Select repetition if specified
        num_features = self.out_shape.features
        rep_idx = repeat_repetition_index(
            sampling_ctx.repetition_index,
            "b r -> b f r ci",
            f=num_features,
            ci=num_input_channels,
        )
        logits = index_tensor(
            logits,
            index=rep_idx,
            dim=2,
            is_differentiable=sampling_ctx.is_differentiable,
        )
        # logits shape: (B, D, C)

        # Condition on evidence if cache has log-likelihoods.
        left_ll = None
        right_ll = None
        if "log_likelihood" in cache:
            if self._two_inputs:
                left_ll = cache["log_likelihood"].get(self.inputs[0])
                right_ll = cache["log_likelihood"].get(self.inputs[1])
            else:
                # Split modules cache a [left, right] pair of tensors.
                split_ll = cache["log_likelihood"].get(self.inputs)
                if isinstance(split_ll, (list, tuple)) and len(split_ll) == 2:
                    left_ll, right_ll = split_ll

        if left_ll is not None and right_ll is not None:
            # Select repetition from the cached child log-likelihoods as well.
            num_features = int(left_ll.shape[1])
            num_input_channels = int(left_ll.shape[2])
            rep_idx_l = repeat_repetition_index(
                sampling_ctx.repetition_index,
                "b r -> b f ci r",
                f=num_features,
                ci=num_input_channels,
            )
            left_ll = index_tensor(
                left_ll,
                index=rep_idx_l,
                dim=-1,
                is_differentiable=sampling_ctx.is_differentiable,
            )
            right_ll = index_tensor(
                right_ll,
                index=rep_idx_l,
                dim=-1,
                is_differentiable=sampling_ctx.is_differentiable,
            )

            # Product log-likelihood for each channel
            prod_ll = left_ll + right_ll  # (B, D, C)

            # Posterior over channels: prior logits + evidence, renormalized.
            log_prior = logits
            log_posterior = log_prior + prod_ll
            log_posterior = log_posterior - torch.logsumexp(log_posterior, dim=-1, keepdim=True)
            logits = log_posterior

        indices = sample_from_logits(
            logits=logits,
            dim=-1,
            is_mpe=sampling_ctx.is_mpe,
            is_differentiable=sampling_ctx.is_differentiable,
            tau=sampling_ctx.tau,
        )

        # Sample from left and right children with same channel index
        # (Linear combination means left and right use the same channel)
        if self._two_inputs:
            # Left child
            left_ctx = sampling_ctx.with_routing(channel_index=indices, mask=sampling_ctx.mask)
            self.inputs[0]._sample(data=data, cache=cache, sampling_ctx=left_ctx)

            # Right child
            right_ctx = sampling_ctx.with_routing(channel_index=indices, mask=sampling_ctx.mask)
            self.inputs[1]._sample(data=data, cache=cache, sampling_ctx=right_ctx)
        else:
            # Single input with Split module - use generic merge_split_indices
            # For LinsumLayer, both left and right use the same indices (linear combination)
            full_indices = self.inputs.merge_split_tensors(indices, indices)
            # Mask widens to the child's (unsplit) feature count.
            full_mask = repeat(sampling_ctx.mask, "b f -> b (f two)", two=2)
            child_ctx = sampling_ctx.with_routing(channel_index=full_indices, mask=full_mask)
            self.inputs._sample(data=data, cache=cache, sampling_ctx=child_ctx)

        return data

    def _expectation_maximization_step(
        self,
        data: Tensor,
        bias_correction: bool = True,
        *,
        cache: Cache,
    ) -> None:
        """Perform EM step to update weights.

        Requires that ``log_likelihood`` was run with caching and that
        gradients were back-propagated, since expected counts are read from
        ``module_lls.grad``.

        Args:
            data: Training data tensor.
            bias_correction: Whether to apply bias correction. Not used by this
                layer's own update; only forwarded to the children.
            cache: Cache with log-likelihoods.

        Raises:
            MissingCacheError: If this module's log-likelihoods are not cached.
        """
        with torch.no_grad():
            # Get cached values
            left_ll, right_ll = self._get_left_right_ll(data, cache)
            module_lls = cache["log_likelihood"].get(self)
            if module_lls is None:
                raise MissingCacheError("Module log-likelihoods not in cache. Call log_likelihood first.")

            # E-step: compute expected counts
            log_weights = rearrange(self.log_weights, "f co r ci -> 1 f co r ci")

            # Product of left and right
            prod_ll = left_ll + right_ll  # (B, D, C, R)
            # Rearrange to match weights: (B, D, O, R, C)
            prod_ll = rearrange(prod_ll, "b f ci r -> b f 1 r ci")

            # Gradients of the root w.r.t. this module's outputs act as
            # responsibilities; epsilon guards log(0).
            log_grads = torch.log(module_lls.grad + 1e-10)
            log_grads = rearrange(log_grads, "b f co r -> b f co r 1")
            module_lls = rearrange(module_lls, "b f co r -> b f co r 1")

            # Compute log expectations
            log_expectations = log_weights + log_grads + prod_ll - module_lls
            log_expectations = log_expectations.logsumexp(0)  # Sum over batch

            # Normalize to get new log weights
            new_log_weights = torch.nn.functional.log_softmax(log_expectations, dim=-1)

            # M-step: update weights
            self.log_weights = new_log_weights

            # Recurse to children
            if self._two_inputs:
                self.inputs[0]._expectation_maximization_step(data, bias_correction=bias_correction, cache=cache)
                self.inputs[1]._expectation_maximization_step(data, bias_correction=bias_correction, cache=cache)
            else:
                self.inputs.inputs._expectation_maximization_step(
                    data, bias_correction=bias_correction, cache=cache
                )
[docs] def marginalize( self, marg_rvs: list[int], prune: bool = True, cache: Cache | None = None, ) -> Optional["LinsumLayer" | Module]: """Marginalize out specified random variables. Args: marg_rvs: Random variable indices to marginalize. prune: Whether to prune unnecessary modules. cache: Cache for memoization. Returns: Marginalized module or None if fully marginalized. """ module_scope = self.scope mutual_rvs = set(module_scope.query).intersection(set(marg_rvs)) # Fully marginalized if len(mutual_rvs) == len(module_scope.query): return None # No overlap - return self unchanged if not mutual_rvs: return self # Partially marginalized if self._two_inputs: left_marg = self.inputs[0].marginalize(marg_rvs, prune=prune, cache=cache) right_marg = self.inputs[1].marginalize(marg_rvs, prune=prune, cache=cache) if left_marg is None and right_marg is None: return None elif left_marg is None: return right_marg elif right_marg is None: return left_marg else: # Both still exist - create new LinsumLayer with marginalized children return LinsumLayer( inputs=[left_marg, right_marg], out_channels=self.out_shape.channels, num_repetitions=self.out_shape.repetitions, ) else: # Single input - marginalize the underlying input marg_input = self.inputs.inputs.marginalize(marg_rvs, prune=prune, cache=cache) if marg_input is None: return None # Check if we still have enough features for LinsumLayer if marg_input.out_shape.features < 2: return marg_input if marg_input.out_shape.features % 2 != 0: # Odd number of features - can't use LinsumLayer return marg_input return LinsumLayer( inputs=marg_input, out_channels=self.out_shape.channels, num_repetitions=self.out_shape.repetitions, )