"""LinsumLayer for efficient linear sum-product operations in probabilistic circuits.
Unlike EinsumLayer which computes a cross-product of input channels,
LinsumLayer computes a linear combination: it adds left/right features
(product in log-space), then applies a weighted sum over input channels.
"""
from __future__ import annotations
from typing import Optional
import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor, nn
from spflow.exceptions import (
InvalidWeightsError,
MissingCacheError,
ScopeError,
ShapeError,
)
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import proj_convex_to_real
from spflow.utils.sampling_context import (
SamplingContext,
index_tensor,
repeat_channel_index,
repeat_repetition_index,
sample_from_logits,
)
class LinsumLayer(Module):
    """LinsumLayer combining product and sum operations with linear channel combination.

    Unlike EinsumLayer which computes a cross-product over channels (I x J combinations),
    LinsumLayer computes a linear combination: it pairs left/right features, adds them
    (product in log-space), then sums over input channels with learned weights.

    This results in fewer parameters: weight_shape = (D_out, O, R, C) vs
    EinsumLayer's (D_out, O, R, I, J).

    Attributes:
        logits (Parameter): Unnormalized log-weights for gradient optimization.
            Normalized views are exposed via the ``weights`` / ``log_weights``
            properties (softmax / log_softmax over the last, input-channel, dim).
    """

    def __init__(
        self,
        inputs: Module | list[Module],
        out_channels: int,
        num_repetitions: int | None = None,
        weights: Tensor | None = None,
        split_mode: SplitMode | None = None,
    ) -> None:
        """Initialize LinsumLayer.

        Args:
            inputs: Either a single module (features will be split into pairs)
                or a list of exactly two modules (left and right children).
                Unlike EinsumLayer, both inputs must have the same number of channels.
            out_channels: Number of output sum nodes per feature.
            num_repetitions: Number of repetitions. If None, inferred from inputs.
            weights: Optional initial weights tensor. If provided, must have shape
                (out_features, out_channels, num_repetitions, in_channels) and be
                positive and normalized over the last dimension.
            split_mode: Optional split configuration for single input mode.
                Use SplitMode.consecutive() or SplitMode.interleaved().
                Defaults to SplitMode.consecutive(num_splits=2) if not specified.

        Raises:
            ValueError: If inputs invalid, out_channels < 1, or weight shape mismatch.
            ScopeError: If the two explicit inputs have overlapping scopes.
        """
        super().__init__()
        # ========== 1. INPUT VALIDATION ==========
        if isinstance(inputs, list):
            if len(inputs) != 2:
                raise ValueError(
                    f"LinsumLayer requires exactly 2 input modules when given a list, got {len(inputs)}."
                )
            self._two_inputs = True
            left_input, right_input = inputs
            # LinsumLayer requires same number of channels (linear combination, not cross-product):
            # the left/right log-likelihoods are added element-wise per channel below.
            if left_input.out_shape.channels != right_input.out_shape.channels:
                raise ValueError(
                    f"LinsumLayer requires left and right inputs to have same number of channels: "
                    f"{left_input.out_shape.channels} != {right_input.out_shape.channels}"
                )
            if left_input.out_shape.features != right_input.out_shape.features:
                raise ValueError(
                    f"Left and right inputs must have same number of features: "
                    f"{left_input.out_shape.features} != {right_input.out_shape.features}"
                )
            if left_input.out_shape.repetitions != right_input.out_shape.repetitions:
                raise ValueError(
                    f"Left and right inputs must have same number of repetitions: "
                    f"{left_input.out_shape.repetitions} != {right_input.out_shape.repetitions}"
                )
            # Validate disjoint scopes: the implicit product node is only valid
            # when its children cover disjoint sets of random variables.
            if not Scope.all_pairwise_disjoint([left_input.scope, right_input.scope]):
                raise ScopeError("Left and right input scopes must be disjoint.")
            self.inputs = nn.ModuleList([left_input, right_input])
            in_channels = left_input.out_shape.channels
            in_features = left_input.out_shape.features
            if num_repetitions is None:
                num_repetitions = left_input.out_shape.repetitions
            self.scope = Scope.join_all([left_input.scope, right_input.scope])
        else:
            # Single input: will split features into left/right halves.
            self._two_inputs = False
            if inputs.out_shape.features < 2:
                raise ValueError(
                    f"LinsumLayer requires at least 2 input features for splitting, "
                    f"got {inputs.out_shape.features}."
                )
            if inputs.out_shape.features % 2 != 0:
                raise ValueError(
                    f"LinsumLayer requires even number of input features for splitting, "
                    f"got {inputs.out_shape.features}."
                )
            # Use Split directly if already a split module, otherwise create from split_mode.
            if isinstance(inputs, Split):
                self.inputs = inputs
            elif split_mode is not None:
                self.inputs = split_mode.create(inputs)
            else:
                # Default: consecutive split with 2 parts.
                self.inputs = SplitConsecutive(inputs)
            in_channels = inputs.out_shape.channels
            # Pairing halves the feature dimension.
            in_features = inputs.out_shape.features // 2
            if num_repetitions is None:
                num_repetitions = inputs.out_shape.repetitions
            self.scope = inputs.scope
        # ========== 2. CONFIGURATION VALIDATION ==========
        if out_channels < 1:
            raise ValueError(f"out_channels must be >= 1, got {out_channels}.")
        # ========== 3. SHAPE COMPUTATION ==========
        self._in_channels = in_channels
        self.in_shape = ModuleShape(in_features, in_channels, num_repetitions)
        self.out_shape = ModuleShape(in_features, out_channels, num_repetitions)
        # ========== 4. WEIGHT INITIALIZATION ==========
        # Linear sum: weight over input channels only (not cross-product).
        self.weights_shape = (
            self.out_shape.features,  # D_out
            self.out_shape.channels,  # O (output channels)
            self.out_shape.repetitions,  # R
            self._in_channels,  # C (input channels - linear, not cross-product)
        )
        if weights is None:
            # Initialize weights randomly, normalized over input channels.
            # The epsilon keeps every weight strictly positive so the setter's
            # positivity check below cannot fail on a zero draw.
            weights = torch.rand(self.weights_shape) + 1e-08
            weights = weights / weights.sum(dim=-1, keepdim=True)
        # Validate weights shape early for a clearer error on user-supplied weights.
        # (The ``weights`` setter re-validates shape, positivity and normalization,
        # but raises ShapeError/InvalidWeightsError instead of ValueError.)
        if weights.shape != self.weights_shape:
            raise ValueError(f"Weight shape mismatch: expected {self.weights_shape}, got {weights.shape}")
        # Register logits parameter; the zeros are immediately overwritten by the
        # weights setter below, which projects the convex weights to logit space.
        self.logits = nn.Parameter(torch.zeros(self.weights_shape))
        # Set weights via property (converts to logits).
        self.weights = weights

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Mapping from output features to their scopes, shape (features, repetitions)."""
        if self._two_inputs:
            # Combine scopes from left and right inputs (pairs are aligned by index).
            left_scopes = self.inputs[0].feature_to_scope
            right_scopes = self.inputs[1].feature_to_scope
            combined = []
            for r in range(self.out_shape.repetitions):
                rep_scopes = []
                for f in range(self.out_shape.features):
                    left_s = left_scopes[f, r]
                    right_s = right_scopes[f, r]
                    rep_scopes.append(Scope.join_all([left_s, right_s]))
                combined.append(np.array(rep_scopes))
            return np.stack(combined, axis=1)
        else:
            # Single input split into halves - combine adjacent pairs.
            # NOTE(review): this pairs input features (2f, 2f+1), which matches an
            # interleaved split; for the default SplitConsecutive (first half vs.
            # second half) the pairing would be (f, f + D/2). Confirm against the
            # feature ordering produced by the configured Split module.
            input_scopes = self.inputs.inputs.feature_to_scope
            combined = []
            for r in range(self.out_shape.repetitions):
                rep_scopes = []
                for f in range(self.out_shape.features):
                    left_s = input_scopes[2 * f, r]
                    right_s = input_scopes[2 * f + 1, r]
                    rep_scopes.append(Scope.join_all([left_s, right_s]))
                combined.append(np.array(rep_scopes))
            return np.stack(combined, axis=1)

    @property
    def log_weights(self) -> Tensor:
        """Log-normalized weights (sum to 1 over input channels)."""
        return torch.nn.functional.log_softmax(self.logits, dim=-1)

    @property
    def weights(self) -> Tensor:
        """Normalized weights (sum to 1 over input channels)."""
        return torch.nn.functional.softmax(self.logits, dim=-1)

    @weights.setter
    def weights(self, values: Tensor) -> None:
        """Set weights (must be positive and sum to 1 over channels).

        Raises:
            ShapeError: If ``values`` does not match ``self.weights_shape``.
            InvalidWeightsError: If any weight is non-positive or channels
                do not sum to 1.
        """
        if values.shape != self.weights_shape:
            raise ShapeError(f"Weight shape mismatch: expected {self.weights_shape}, got {values.shape}")
        if not torch.all(values > 0):
            raise InvalidWeightsError("Weights must be positive.")
        sums = values.sum(dim=-1)
        if not torch.allclose(sums, torch.ones_like(sums)):
            raise InvalidWeightsError("Weights must sum to 1 over input channels.")
        # Project to logits space (inverse of the softmax normalization).
        self.logits.data = proj_convex_to_real(values)

    @log_weights.setter
    def log_weights(self, values: Tensor) -> None:
        """Set log weights directly (no positivity/normalization check)."""
        if values.shape != self.weights_shape:
            raise ShapeError(f"Log weight shape mismatch: expected {self.weights_shape}, got {values.shape}")
        self.logits.data = values

    def extra_repr(self) -> str:
        """Append the weight shape to the default module repr."""
        return f"{super().extra_repr()}, weights={self.weights_shape}"

    def _get_left_right_ll(self, data: Tensor, cache: Cache | None = None) -> tuple[Tensor, Tensor]:
        """Get log-likelihoods from left and right children.

        Args:
            data: Input data passed through to the children.
            cache: Optional cache forwarded to the children's log_likelihood.

        Returns:
            Tuple of (left_ll, right_ll), each of shape (batch, features, channels, reps).
        """
        if self._two_inputs:
            left_ll = self.inputs[0].log_likelihood(data, cache=cache)
            right_ll = self.inputs[1].log_likelihood(data, cache=cache)
        else:
            # The Split module returns a list of [left, right] log-likelihoods.
            lls = self.inputs.log_likelihood(data, cache=cache)
            left_ll = lls[0]
            right_ll = lls[1]
        return left_ll, right_ll

    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log-likelihood using linear sum over channels.

        Unlike EinsumLayer which computes cross-product (I x J), this computes
        a linear combination: add left+right (product), then logsumexp over channels.

        Args:
            data: Input data of shape (batch_size, num_features).
            cache: Optional cache for intermediate results.

        Returns:
            Log-likelihood tensor of shape (batch, out_features, out_channels, reps).
        """
        # Get child log-likelihoods.
        left_ll, right_ll = self._get_left_right_ll(data, cache)
        # Dimensions: N=batch, D=features, C=channels, R=reps.
        N, D, C, R = left_ll.size()
        # Product: left + right in log-space.
        # Shape: (N, D, C, R)
        prod_ll = left_ll + right_ll
        # Expand for output channels dimension.
        # prod_ll: (N, D, C, R) -> (N, D, 1, C, R)
        prod_ll = rearrange(prod_ll, "n f ci r -> n f 1 ci r")
        # Get log weights: (D, O, R, C) -> (1, D, O, C, R)
        log_weights = rearrange(self.log_weights, "f co r ci -> 1 f co ci r")
        # Weighted sum over input channels (broadcast over batch and output channels):
        # (N, D, 1, C, R) + (1, D, O, C, R) -> (N, D, O, C, R)
        weighted_ll = prod_ll + log_weights
        # LogSumExp over input channels (dim=3).
        log_prob = torch.logsumexp(weighted_ll, dim=3)  # (N, D, O, R)
        return log_prob

    def _sample(
        self,
        data: Tensor,
        sampling_ctx: SamplingContext,
        cache: Cache,
    ) -> Tensor:
        """Sample from the LinsumLayer (in-place into ``data``).

        Selects an input channel per (sample, feature) from the sum weights --
        conditioned on evidence when child log-likelihoods are cached -- then
        recurses into the children with the chosen channel indices. Because the
        layer is a linear combination (not a cross-product), left and right
        children receive the SAME channel index.

        Args:
            data: Data tensor with evidence (NaN for missing); filled in-place
                by the leaf modules during recursion.
            sampling_ctx: Sampling context with channel/repetition indices,
                mask, MPE/differentiability flags and temperature.
            cache: Cache; if it holds child log-likelihoods, sampling is
                conditioned on that evidence via the posterior over channels.

        Returns:
            The (in-place updated) data tensor.
        """
        # Validate and broadcast the routing information to this layer's shape.
        sampling_ctx.validate_sampling_context(
            num_samples=data.shape[0],
            num_features=self.out_shape.features,
            num_channels=self.out_shape.channels,
            num_repetitions=self.out_shape.repetitions,
            allowed_feature_widths=(1, self.out_shape.features),
        )
        sampling_ctx.broadcast_feature_width(target_features=self.out_shape.features, allow_from_one=True)
        # Get logits and select based on context.
        logits = self.logits  # (D, O, R, C)
        # Expand for batch dimension.
        batch_size = int(sampling_ctx.channel_index.shape[0])
        logits = repeat(logits, "f co r ci -> b f co r ci", b=batch_size)
        # logits shape: (B, D, O, R, C)
        # Select output channel based on parent's channel_index.
        channel_idx = sampling_ctx.channel_index
        # Gather the correct output channel.
        num_repetitions = self.out_shape.repetitions
        num_input_channels = self._in_channels
        idx = repeat_channel_index(
            channel_idx,
            "b f co -> b f co r ci",
            r=num_repetitions,
            ci=num_input_channels,
        )
        logits = index_tensor(
            logits,
            index=idx,
            dim=2,
            is_differentiable=sampling_ctx.is_differentiable,
        )
        # logits shape: (B, D, R, C)
        # Select repetition if specified.
        num_features = self.out_shape.features
        rep_idx = repeat_repetition_index(
            sampling_ctx.repetition_index,
            "b r -> b f r ci",
            f=num_features,
            ci=num_input_channels,
        )
        logits = index_tensor(
            logits,
            index=rep_idx,
            dim=2,
            is_differentiable=sampling_ctx.is_differentiable,
        )
        # logits shape: (B, D, C)
        # Condition on evidence if cache has log-likelihoods.
        left_ll = None
        right_ll = None
        if "log_likelihood" in cache:
            if self._two_inputs:
                left_ll = cache["log_likelihood"].get(self.inputs[0])
                right_ll = cache["log_likelihood"].get(self.inputs[1])
            else:
                split_ll = cache["log_likelihood"].get(self.inputs)
                if isinstance(split_ll, (list, tuple)) and len(split_ll) == 2:
                    left_ll, right_ll = split_ll
        if left_ll is not None and right_ll is not None:
            # Select the sampled repetition from the cached child LLs (B, D, C, R) -> (B, D, C).
            num_features = int(left_ll.shape[1])
            num_input_channels = int(left_ll.shape[2])
            rep_idx_l = repeat_repetition_index(
                sampling_ctx.repetition_index,
                "b r -> b f ci r",
                f=num_features,
                ci=num_input_channels,
            )
            left_ll = index_tensor(
                left_ll,
                index=rep_idx_l,
                dim=-1,
                is_differentiable=sampling_ctx.is_differentiable,
            )
            right_ll = index_tensor(
                right_ll,
                index=rep_idx_l,
                dim=-1,
                is_differentiable=sampling_ctx.is_differentiable,
            )
            # Product log-likelihood for each channel.
            prod_ll = left_ll + right_ll  # (B, D, C)
            # Posterior over channels: prior (weights) times evidence, renormalized in log-space.
            log_prior = logits
            log_posterior = log_prior + prod_ll
            log_posterior = log_posterior - torch.logsumexp(log_posterior, dim=-1, keepdim=True)
            logits = log_posterior
        indices = sample_from_logits(
            logits=logits,
            dim=-1,
            is_mpe=sampling_ctx.is_mpe,
            is_differentiable=sampling_ctx.is_differentiable,
            tau=sampling_ctx.tau,
        )
        # Sample from left and right children with same channel index
        # (Linear combination means left and right use the same channel).
        if self._two_inputs:
            # Left child.
            left_ctx = sampling_ctx.with_routing(channel_index=indices, mask=sampling_ctx.mask)
            self.inputs[0]._sample(data=data, cache=cache, sampling_ctx=left_ctx)
            # Right child.
            right_ctx = sampling_ctx.with_routing(channel_index=indices, mask=sampling_ctx.mask)
            self.inputs[1]._sample(data=data, cache=cache, sampling_ctx=right_ctx)
        else:
            # Single input with Split module - use generic merge_split_indices.
            # For LinsumLayer, both left and right use the same indices (linear combination).
            full_indices = self.inputs.merge_split_tensors(indices, indices)
            # NOTE(review): "(f two)" duplicates mask entries pairwise (interleaved
            # ordering); verify this matches the feature ordering that
            # merge_split_tensors produces for the configured split mode.
            full_mask = repeat(sampling_ctx.mask, "b f -> b (f two)", two=2)
            child_ctx = sampling_ctx.with_routing(channel_index=full_indices, mask=full_mask)
            self.inputs._sample(data=data, cache=cache, sampling_ctx=child_ctx)
        return data

    def _expectation_maximization_step(
        self,
        data: Tensor,
        bias_correction: bool = True,
        *,
        cache: Cache,
    ) -> None:
        """Perform EM step to update weights, then recurse into children.

        Requires that ``log_likelihood`` has been computed with this cache and
        that gradients have been backpropagated into the cached module LLs
        (``module_lls.grad`` is read below).

        Args:
            data: Training data tensor.
            bias_correction: Whether to apply bias correction (forwarded to children).
            cache: Cache with log-likelihoods.

        Raises:
            MissingCacheError: If this module's log-likelihoods are not cached.
        """
        with torch.no_grad():
            # Get cached values.
            left_ll, right_ll = self._get_left_right_ll(data, cache)
            module_lls = cache["log_likelihood"].get(self)
            if module_lls is None:
                raise MissingCacheError("Module log-likelihoods not in cache. Call log_likelihood first.")
            # E-step: compute expected counts.
            log_weights = rearrange(self.log_weights, "f co r ci -> 1 f co r ci")
            # Product of left and right.
            prod_ll = left_ll + right_ll  # (B, D, C, R)
            # Rearrange to match weights: (B, D, O, R, C).
            prod_ll = rearrange(prod_ll, "b f ci r -> b f 1 r ci")
            # Gradients of the root w.r.t. this module's LLs act as responsibilities;
            # the epsilon guards log(0) for zero gradients.
            log_grads = torch.log(module_lls.grad + 1e-10)
            log_grads = rearrange(log_grads, "b f co r -> b f co r 1")
            module_lls = rearrange(module_lls, "b f co r -> b f co r 1")
            # Compute log expectations (divide out the module LL in log-space).
            log_expectations = log_weights + log_grads + prod_ll - module_lls
            log_expectations = log_expectations.logsumexp(0)  # Sum over batch
            # Normalize to get new log weights.
            new_log_weights = torch.nn.functional.log_softmax(log_expectations, dim=-1)
            # M-step: update weights.
            self.log_weights = new_log_weights
            # Recurse to children.
            if self._two_inputs:
                self.inputs[0]._expectation_maximization_step(data, bias_correction=bias_correction, cache=cache)
                self.inputs[1]._expectation_maximization_step(data, bias_correction=bias_correction, cache=cache)
            else:
                # Skip the Split wrapper and recurse into the underlying module.
                self.inputs.inputs._expectation_maximization_step(
                    data, bias_correction=bias_correction, cache=cache
                )

    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> Optional["LinsumLayer" | Module]:
        """Marginalize out specified random variables.

        Args:
            marg_rvs: Random variable indices to marginalize.
            prune: Whether to prune unnecessary modules.
            cache: Cache for memoization.

        Returns:
            Marginalized module or None if fully marginalized. May return a
            child module directly when the other side is fully marginalized
            (or when too few features remain to pair).
        """
        module_scope = self.scope
        mutual_rvs = set(module_scope.query).intersection(set(marg_rvs))
        # Fully marginalized.
        if len(mutual_rvs) == len(module_scope.query):
            return None
        # No overlap - return self unchanged.
        if not mutual_rvs:
            return self
        # Partially marginalized.
        if self._two_inputs:
            left_marg = self.inputs[0].marginalize(marg_rvs, prune=prune, cache=cache)
            right_marg = self.inputs[1].marginalize(marg_rvs, prune=prune, cache=cache)
            if left_marg is None and right_marg is None:
                return None
            elif left_marg is None:
                return right_marg
            elif right_marg is None:
                return left_marg
            else:
                # Both still exist - create new LinsumLayer with marginalized children.
                # NOTE(review): learned weights are not carried over (the feature
                # count may have changed, so the old weight tensor may not fit) --
                # the new layer re-initializes randomly. Confirm this is intended.
                return LinsumLayer(
                    inputs=[left_marg, right_marg],
                    out_channels=self.out_shape.channels,
                    num_repetitions=self.out_shape.repetitions,
                )
        else:
            # Single input - marginalize the underlying input (bypassing the Split wrapper).
            marg_input = self.inputs.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
            if marg_input is None:
                return None
            # Check if we still have enough features for LinsumLayer.
            if marg_input.out_shape.features < 2:
                return marg_input
            if marg_input.out_shape.features % 2 != 0:
                # Odd number of features - can't use LinsumLayer.
                return marg_input
            # NOTE(review): the original split_mode is not preserved here -- a
            # non-default split (e.g. interleaved) silently becomes the default
            # consecutive split, and learned weights are dropped. Confirm intended.
            return LinsumLayer(
                inputs=marg_input,
                out_channels=self.out_shape.channels,
                num_repetitions=self.out_shape.repetitions,
            )