"""EinsumLayer for efficient sum-product operations in probabilistic circuits.
This layer evaluates weighted sums over pairwise channel products in log space.
The numerically stable outer structure still follows the LogEinsumExp trick,
but the inner contraction is scheduled to limit memory traffic on large channel
grids:
- normalize routing weights over the flattened ``(left_channel, right_channel)``
axis
- keep an eval/no-grad cache of those normalized weights so inference does not
pay for a repeated softmax when parameters are unchanged
- use a factored two-stage contraction when the full ``left x right`` pair grid
is large enough that a single einsum becomes memory-bandwidth bound
The factored contraction is mathematically equivalent to the dense einsum, but
it maps better to the backends that power large CPU/GPU workloads and reduces
allocator pressure in the wide-channel regimes where ``EinsumLayer`` is hottest.
"""
from __future__ import annotations
from typing import Optional
import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor, nn
from spflow.exceptions import (
InvalidWeightsError,
MissingCacheError,
ScopeError,
ShapeError,
)
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import proj_convex_to_real
from spflow.utils.sampling_context import (
SamplingContext,
index_tensor,
repeat_channel_index,
repeat_repetition_index,
sample_from_logits,
)
[docs]
class EinsumLayer(Module):
"""EinsumLayer combining product and sum operations efficiently.
Implements sum(product(x)) using einsum for circuits with arbitrary tree
structure. Takes pairs of adjacent features as left/right children, computes
their cross-product over channels, and sums with learned weights.
The LogEinsumExp trick is used for numerical stability in log-space. For
large channel grids, the weighted channel contraction is factorized into two
smaller contractions to reduce bandwidth pressure without changing outputs.
Attributes:
logits (Parameter): Unnormalized log-weights for gradient optimization.
unraveled_channel_indices (Tensor): Mapping from flat to (i,j) channel pairs.
"""
[docs]
def __init__(
    self,
    inputs: Module | list[Module],
    out_channels: int,
    num_repetitions: int | None = None,
    weights: Tensor | None = None,
    split_mode: SplitMode | None = None,
) -> None:
    """Build the layer from one splittable module or an explicit (left, right) pair.

    Args:
        inputs: A single module whose features are split into a left and a
            right part, or a list holding exactly the left and right child.
        out_channels: Number of output sum nodes per feature (must be >= 1).
        num_repetitions: Number of repetitions; taken from the input shape
            when ``None``.
        weights: Optional initial weights of shape
            ``(out_features, out_channels, num_repetitions, left_channels,
            right_channels)``. Random normalized weights are drawn otherwise.
        split_mode: Split configuration used in single-input mode when
            ``inputs`` is not already a ``Split``; falls back to
            ``SplitConsecutive`` when omitted.

    Raises:
        ValueError: On a list of the wrong length, mismatching child
            features/repetitions, too few or an odd number of features in
            single-input mode, ``out_channels < 1`` or a weight shape mismatch.
        ScopeError: If the two children do not have disjoint scopes.
    """
    super().__init__()

    # ---------- children ----------
    given_pair = isinstance(inputs, list)
    if given_pair and len(inputs) != 2:
        raise ValueError(
            f"EinsumLayer requires exactly 2 input modules when given a list, got {len(inputs)}."
        )
    self._two_inputs = given_pair

    if given_pair:
        lhs, rhs = inputs
        lhs_shape = lhs.out_shape
        rhs_shape = rhs.out_shape
        # Channel counts may differ (cross-product); features and repetitions may not.
        if lhs_shape.features != rhs_shape.features:
            raise ValueError(
                f"Left and right inputs must have same number of features: "
                f"{lhs_shape.features} != {rhs_shape.features}"
            )
        if lhs_shape.repetitions != rhs_shape.repetitions:
            raise ValueError(
                f"Left and right inputs must have same number of repetitions: "
                f"{lhs_shape.repetitions} != {rhs_shape.repetitions}"
            )
        if not Scope.all_pairwise_disjoint([lhs.scope, rhs.scope]):
            raise ScopeError("Left and right input scopes must be disjoint.")

        self.inputs = nn.ModuleList([lhs, rhs])
        self._left_channels = lhs_shape.channels
        self._right_channels = rhs_shape.channels
        in_features = lhs_shape.features
        if num_repetitions is None:
            num_repetitions = lhs_shape.repetitions
        self.scope = Scope.join_all([lhs.scope, rhs.scope])
    else:
        source_shape = inputs.out_shape
        if source_shape.features < 2:
            raise ValueError(
                f"EinsumLayer requires at least 2 input features for splitting, "
                f"got {source_shape.features}."
            )
        if source_shape.features % 2 != 0:
            raise ValueError(
                f"EinsumLayer requires even number of input features for splitting, "
                f"got {source_shape.features}."
            )

        # Reuse an existing Split, honour an explicit split_mode, else split consecutively.
        if isinstance(inputs, Split):
            splitter = inputs
        elif split_mode is not None:
            splitter = split_mode.create(inputs)
        else:
            splitter = SplitConsecutive(inputs)
        self.inputs = splitter

        # Both halves come from the same module, so they share its channel count.
        self._left_channels = source_shape.channels
        self._right_channels = source_shape.channels
        in_features = source_shape.features // 2
        if num_repetitions is None:
            num_repetitions = source_shape.repetitions
        self.scope = inputs.scope

    # ---------- configuration ----------
    if out_channels < 1:
        raise ValueError(f"out_channels must be >= 1, got {out_channels}.")

    # ---------- shapes ----------
    # in_shape reports the wider of the two children (informational only).
    widest_child = max(self._left_channels, self._right_channels)
    self.in_shape = ModuleShape(in_features, widest_child, num_repetitions)
    self.out_shape = ModuleShape(in_features, out_channels, num_repetitions)

    # Layout: (features, out_channels, repetitions, left_channels, right_channels)
    self.weights_shape = (
        self.out_shape.features,
        self.out_shape.channels,
        self.out_shape.repetitions,
        self._left_channels,
        self._right_channels,
    )

    # ---------- index buffers ----------
    # Row k holds the (left, right) channel pair addressed by flat index k.
    pair_grid = [
        (left_idx, right_idx)
        for left_idx in range(self._left_channels)
        for right_idx in range(self._right_channels)
    ]
    self.register_buffer(
        "unraveled_channel_indices",
        torch.tensor(pair_grid, dtype=torch.long),
    )
    # One-hot projections from a flat pair index onto each child's channel axis.
    float_dtype = torch.get_default_dtype()
    for buffer_name, column, width in (
        ("flat_to_left_channels", 0, self._left_channels),
        ("flat_to_right_channels", 1, self._right_channels),
    ):
        one_hot = torch.nn.functional.one_hot(
            self.unraveled_channel_indices[:, column],
            num_classes=width,
        )
        self.register_buffer(buffer_name, one_hot.to(dtype=float_dtype))

    # ---------- weights ----------
    if weights is None:
        # Strictly positive random weights, normalized over the (i, j) pair grid.
        raw = torch.rand(self.weights_shape) + 1e-08
        weights = raw / raw.sum(dim=(-2, -1), keepdim=True)
    if weights.shape != self.weights_shape:
        raise ValueError(f"Weight shape mismatch: expected {self.weights_shape}, got {weights.shape}")

    self.logits = nn.Parameter(torch.zeros(self.weights_shape))
    self._eval_weight_cache: tuple[int, Tensor] | None = None
    # The property setter validates the weights and projects them to logits.
    self.weights = weights
@property
def feature_to_scope(self) -> np.ndarray:
    """Scope of every output feature, indexed as ``[feature, repetition]``.

    Each output feature's scope is the join of the scopes of the two input
    features it combines.
    """
    if self._two_inputs:
        # Output feature f pairs feature f of the left child with feature f of the right.
        left_table = self.inputs[0].feature_to_scope
        right_table = self.inputs[1].feature_to_scope
        stride, right_offset = 1, 0
    else:
        # Output feature f pairs input features 2f and 2f + 1 of the wrapped module.
        # NOTE(review): this assumes an adjacent-pair layout; confirm it matches
        # the feature order produced by the configured Split module.
        left_table = right_table = self.inputs.inputs.feature_to_scope
        stride, right_offset = 2, 1

    num_features = self.out_shape.features
    per_repetition = []
    for rep in range(self.out_shape.repetitions):
        joined = [
            Scope.join_all(
                [
                    left_table[stride * feat, rep],
                    right_table[stride * feat + right_offset, rep],
                ]
            )
            for feat in range(num_features)
        ]
        per_repetition.append(np.array(joined))
    return np.stack(per_repetition, axis=1)
@property
def log_weights(self) -> Tensor:
    """Log of the routing weights, normalized over the flattened (left, right) pair axis."""
    pair_shape = (self._left_channels, self._right_channels)
    # Normalize jointly over all (i, j) pairs by flattening the two trailing axes.
    normalized = torch.nn.functional.log_softmax(self.logits.flatten(start_dim=-2), dim=-1)
    return normalized.unflatten(-1, pair_shape)
@property
def weights(self) -> Tensor:
    """Routing weights in probability space (each (i, j) grid sums to 1)."""
    normalized = self._normalized_weights()
    return normalized
@weights.setter
def weights(self, values: Tensor) -> None:
    """Replace the routing weights.

    Args:
        values: Strictly positive tensor of shape ``self.weights_shape`` whose
            trailing ``(i, j)`` grid sums to 1.

    Raises:
        ShapeError: If ``values`` does not have shape ``self.weights_shape``.
        InvalidWeightsError: If a weight is not positive or a pair grid does
            not sum to 1.
    """
    if values.shape != self.weights_shape:
        raise ShapeError(f"Weight shape mismatch: expected {self.weights_shape}, got {values.shape}")
    if not (values > 0).all():
        raise InvalidWeightsError("Weights must be positive.")
    pair_totals = values.sum(dim=(-2, -1))
    if not torch.allclose(pair_totals, torch.ones_like(pair_totals)):
        raise InvalidWeightsError("Weights must sum to 1 over (i,j) channel pairs.")

    # Project onto unconstrained logits over the flattened pair axis, then
    # restore the (i, j) grid layout.
    projected = proj_convex_to_real(rearrange(values, "f co r i j -> f co r (i j)"))
    self.logits.data = rearrange(
        projected,
        "f co r (i j) -> f co r i j",
        i=self._left_channels,
        j=self._right_channels,
    )
    # Drop the cached normalized weights explicitly after replacing the logits.
    self._clear_eval_weight_cache()
@log_weights.setter
def log_weights(self, values: Tensor) -> None:
    """Store ``values`` directly as the layer's logits.

    Raises:
        ShapeError: If ``values`` does not have shape ``self.weights_shape``.
    """
    shape_matches = values.shape == self.weights_shape
    if not shape_matches:
        raise ShapeError(f"Log weight shape mismatch: expected {self.weights_shape}, got {values.shape}")
    self.logits.data = values
    # The cached normalized weights belong to the previous logits.
    self._clear_eval_weight_cache()
def extra_repr(self) -> str:
    """Extend the parent's ``extra_repr`` with the weight tensor shape."""
    base_repr = super().extra_repr()
    return f"{base_repr}, weights={self.weights_shape}"
def _get_left_right_ll(self, data: Tensor, cache: Cache | None = None) -> tuple[Tensor, Tensor]:
    """Evaluate both children and return their log-likelihoods.

    Args:
        data: Input data forwarded to the children's ``log_likelihood``.
        cache: Optional cache forwarded to the children.

    Returns:
        ``(left_ll, right_ll)``. In two-input mode each child is evaluated
        separately (left first); in single-input mode the first two entries of
        the split module's result are used.
    """
    if not self._two_inputs:
        # The split module yields its parts as a sequence: [left, right].
        split_lls = self.inputs.log_likelihood(data, cache=cache)
        return split_lls[0], split_lls[1]

    left_child, right_child = self.inputs[0], self.inputs[1]
    left_ll = left_child.log_likelihood(data, cache=cache)
    right_ll = right_child.log_likelihood(data, cache=cache)
    return left_ll, right_ll
def _clear_eval_weight_cache(self) -> None:
    """Forget the normalized weights cached for eval/no-grad execution.

    The next eval/no-grad lookup recomputes them from the current logits.
    """
    self._eval_weight_cache = None
def _normalized_weights(self) -> Tensor:
    """Return routing weights normalized over the (left, right) pair grid.

    In training mode, or whenever gradients are enabled, the softmax is always
    recomputed so autograd sees the live parameter. In eval mode with
    gradients disabled, the result is cached together with the version counter
    of ``self.logits`` and reused while version, device and dtype still match.

    If the logits do not expose a version counter (inference tensors, e.g. a
    module constructed under ``torch.inference_mode()``), staleness cannot be
    detected, so the cache is bypassed instead of raising.

    Returns:
        Tensor of shape ``self.logits.shape`` whose trailing ``(i, j)`` grid
        sums to 1.
    """
    pair_shape = (self._left_channels, self._right_channels)

    def _softmax_over_pairs() -> Tensor:
        # Normalize jointly over all (i, j) pairs, then restore the grid layout.
        flat_logits = self.logits.flatten(start_dim=-2)
        flat_weights = torch.nn.functional.softmax(flat_logits, dim=-1)
        return flat_weights.unflatten(-1, pair_shape)

    if self.training or torch.is_grad_enabled():
        return _softmax_over_pairs()

    try:
        current_version = self.logits._version
    except RuntimeError:
        # Inference tensors do not track a version counter; without it a stale
        # entry could never be detected, so skip caching entirely.
        self._eval_weight_cache = None
        return _softmax_over_pairs()

    cache_entry = self._eval_weight_cache
    if cache_entry is not None:
        cached_version, cached_weights = cache_entry
        if (
            cached_version == current_version
            and cached_weights.device == self.logits.device
            and cached_weights.dtype == self.logits.dtype
        ):
            return cached_weights

    normalized = _softmax_over_pairs()
    self._eval_weight_cache = (current_version, normalized)
    return normalized
def _use_factored_probability_contraction(self, device: torch.device) -> bool:
    """Decide whether the two-stage contraction should replace the dense einsum.

    The decision depends on the size of the ``left_channels * right_channels``
    pair grid: on CUDA the factored path is chosen from 128 pairs upward; on
    other devices it needs at least 256 pairs and fewer output channels than
    the narrower of the two inputs.

    Args:
        device: Device the contraction will run on.

    Returns:
        ``True`` when the factored schedule should be used.
    """
    pair_count = self._left_channels * self._right_channels
    if device.type == "cuda":
        return pair_count >= 128
    if pair_count < 256:
        return False
    narrowest_input = min(self._left_channels, self._right_channels)
    return self.out_shape.channels < narrowest_input
def _contract_probabilities(self, left_prob: Tensor, right_prob: Tensor, weights: Tensor) -> Tensor:
    """Sum weighted pairwise channel products into output channels.

    Computes the ``ndir,ndjr,dorij->ndor`` contraction either as one dense
    einsum or, when ``_use_factored_probability_contraction`` says so, as two
    successive einsums that first fold the weights with the child having more
    channels (the left child on ties).

    Args:
        left_prob: Left child probabilities, indexed ``n d i r``.
        right_prob: Right child probabilities, indexed ``n d j r``.
        weights: Routing weights, indexed ``d o r i j``.

    Returns:
        Contracted tensor indexed ``n d o r``.
    """
    if self._use_factored_probability_contraction(left_prob.device):
        if self._left_channels >= self._right_channels:
            # Eliminate the left channel axis first, then the right one.
            partial = torch.einsum("ndir,dorij->ndjor", left_prob, weights)
            return torch.einsum("ndjr,ndjor->ndor", right_prob, partial)
        # Right child is wider: eliminate its channel axis first.
        partial = torch.einsum("ndjr,dorij->ndior", right_prob, weights)
        return torch.einsum("ndir,ndior->ndor", left_prob, partial)
    return torch.einsum("ndir,ndjr,dorij->ndor", left_prob, right_prob, weights)
def _log_likelihood_from_inputs(self, left_ll: Tensor, right_ll: Tensor) -> Tensor:
    """Combine child log-likelihoods into this layer's log-likelihood.

    Each child's log-likelihoods are shifted by their maximum over the channel
    axis (dim 2) before exponentiation, the shifted probabilities are
    contracted with the normalized weights, and the shifts are added back
    after taking the log.

    Where all channels of a child are ``-inf`` the maximum is ``-inf`` as
    well; subtracting it would give ``-inf - (-inf) = NaN``. The shift is
    replaced by 0 in those slots so the result is ``-inf`` (zero likelihood)
    rather than NaN.

    Args:
        left_ll: Left child log-likelihoods, indexed ``n d i r``.
        right_ll: Right child log-likelihoods, indexed ``n d j r``.

    Returns:
        Log-likelihood tensor indexed ``n d o r``.
    """

    def _stable_shift(ll: Tensor) -> Tensor:
        # Per-slot channel maximum, with -inf replaced by 0 to avoid inf - inf.
        peak = ll.amax(dim=2, keepdim=True)
        return torch.where(torch.isneginf(peak), torch.zeros_like(peak), peak)

    left_shift = _stable_shift(left_ll)
    right_shift = _stable_shift(right_ll)
    left_prob = torch.exp(left_ll - left_shift)
    right_prob = torch.exp(right_ll - right_shift)

    weights = self._normalized_weights()
    prob = self._contract_probabilities(left_prob, right_prob, weights)
    # Undo the shifts in log space.
    return torch.log(prob) + left_shift + right_shift
[docs]
@cached
def log_likelihood(
    self,
    data: Tensor,
    cache: Cache | None = None,
) -> Tensor:
    """Evaluate the layer's log-likelihood for ``data``.

    Args:
        data: Input data of shape (batch_size, num_features).
        cache: Optional cache for intermediate results; forwarded to the
            children.

    Returns:
        Log-likelihood tensor of shape (batch, out_features, out_channels, reps).
    """
    # Evaluate the children, then mix their outputs in a numerically stable way.
    children_ll = self._get_left_right_ll(data, cache)
    return self._log_likelihood_from_inputs(*children_ll)
def _sample(
    self,
    data: Tensor,
    sampling_ctx: SamplingContext,
    cache: Cache,
) -> Tensor:
    """Route a sampling request through this layer to its children.

    For every sample and output feature, one ``(left, right)`` input-channel
    pair is chosen from this layer's logits via ``sample_from_logits``
    (honouring ``sampling_ctx.is_mpe``). If ``cache`` holds log-likelihoods
    for the children, the logits are first combined with them so the choice
    is conditioned on that evidence. The chosen channels are then passed to
    the children through new routing contexts.

    Args:
        data: Data tensor forwarded to the children's ``_sample``; its first
            dimension is used as the number of samples.
        sampling_ctx: Sampling context providing ``channel_index``,
            ``repetition_index``, ``mask``, ``is_mpe``, ``is_differentiable``
            and ``tau``.
        cache: Cache; when it has a ``"log_likelihood"`` entry for the
            children, those values are used for conditional sampling.

    Returns:
        The ``data`` object that was passed in (presumably filled in place by
        the children -- TODO confirm against the child modules).
    """
    # Check the incoming context against this layer's output shape
    sampling_ctx.validate_sampling_context(
        num_samples=data.shape[0],
        num_features=self.out_shape.features,
        num_channels=self.out_shape.channels,
        num_repetitions=self.out_shape.repetitions,
        allowed_feature_widths=(1, self.out_shape.features),
    )
    sampling_ctx.broadcast_feature_width(target_features=self.out_shape.features, allow_from_one=True)
    # Get logits and select based on context
    logits = self.logits  # (D, O, R, I, J)
    # Expand for batch dimension
    batch_size = int(sampling_ctx.channel_index.shape[0])
    logits = repeat(logits, "f co r i j -> b f co r i j", b=batch_size)
    # logits shape: (B, D, O, R, I, J)
    # Select output channel based on parent's channel_index
    # sampling_ctx.channel_index indexes out_channels.
    channel_idx = sampling_ctx.channel_index
    # Gather the correct output channel
    # Expand channel_idx to match logits dimensions
    num_repetitions = self.out_shape.repetitions
    num_left_channels = self._left_channels
    num_right_channels = self._right_channels
    idx = repeat_channel_index(
        channel_idx,
        "b f co -> b f co r i j",
        r=num_repetitions,
        i=num_left_channels,
        j=num_right_channels,
    )
    logits = index_tensor(
        logits,
        index=idx,
        dim=2,
        is_differentiable=sampling_ctx.is_differentiable,
    )
    # logits shape: (B, D, R, I, J)
    # Select the repetition; after the channel axis is gone it sits at dim 2
    num_features = self.out_shape.features
    rep_idx = repeat_repetition_index(
        sampling_ctx.repetition_index,
        "b r -> b f r i j",
        f=num_features,
        i=num_left_channels,
        j=num_right_channels,
    )
    logits = index_tensor(
        logits,
        index=rep_idx,
        dim=2,
        is_differentiable=sampling_ctx.is_differentiable,
    )
    # logits shape: (B, D, I, J)
    # Flatten (I, J) for categorical sampling
    logits_flat = rearrange(logits, "b f i j -> b f (i j)")
    # Condition on evidence if cache has log-likelihoods.
    left_ll = None
    right_ll = None
    if "log_likelihood" in cache:
        if self._two_inputs:
            left_ll = cache["log_likelihood"].get(self.inputs[0])
            right_ll = cache["log_likelihood"].get(self.inputs[1])
        else:
            # The split module's cached value is only used when it is a (left, right) pair
            split_ll = cache["log_likelihood"].get(self.inputs)
            if isinstance(split_ll, (list, tuple)) and len(split_ll) == 2:
                left_ll, right_ll = split_ll
    if left_ll is not None and right_ll is not None:
        # Select repetition (sizes are re-read from the cached tensors)
        num_features = int(left_ll.shape[1])
        num_left_channels = int(left_ll.shape[2])
        num_right_channels = int(right_ll.shape[2])
        rep_idx_l = repeat_repetition_index(
            sampling_ctx.repetition_index,
            "b r -> b f i r",
            f=num_features,
            i=num_left_channels,
        )
        rep_idx_r = repeat_repetition_index(
            sampling_ctx.repetition_index,
            "b r -> b f j r",
            f=num_features,
            j=num_right_channels,
        )
        left_ll = index_tensor(
            left_ll,
            index=rep_idx_l,
            dim=-1,
            is_differentiable=sampling_ctx.is_differentiable,
        )
        right_ll = index_tensor(
            right_ll,
            index=rep_idx_r,
            dim=-1,
            is_differentiable=sampling_ctx.is_differentiable,
        )
        # Compute joint log-likelihood for each (i, j) pair
        # left_ll: (B, D, I), right_ll: (B, D, J)
        left_ll = rearrange(left_ll, "b f i -> b f i 1")
        right_ll = rearrange(right_ll, "b f j -> b f 1 j")
        joint_ll = left_ll + right_ll  # (B, D, I, J)
        joint_ll_flat = rearrange(joint_ll, "b f i j -> b f (i j)")
        # Compute posterior: prior logits plus evidence, renormalized over pairs
        log_prior = logits_flat
        log_posterior = log_prior + joint_ll_flat
        log_posterior = log_posterior - torch.logsumexp(log_posterior, dim=-1, keepdim=True)
        logits_flat = log_posterior
    indices = sample_from_logits(
        logits=logits_flat,
        dim=-1,
        is_mpe=sampling_ctx.is_mpe,
        is_differentiable=sampling_ctx.is_differentiable,
        tau=sampling_ctx.tau,
    )
    # Unravel indices to (i, j) pairs
    if sampling_ctx.is_differentiable:
        # Differentiable path: project the selection over flat pairs onto each
        # child's channel axis with the one-hot buffers.
        left_projection = self.flat_to_left_channels.to(device=indices.device, dtype=indices.dtype)
        right_projection = self.flat_to_right_channels.to(device=indices.device, dtype=indices.dtype)
        left_indices = indices @ left_projection  # (B, D, I)
        right_indices = indices @ right_projection  # (B, D, J)
    else:
        ij_indices = self.unraveled_channel_indices[indices]  # (B, D, 2)
        left_indices = ij_indices[..., 0]  # (B, D)
        right_indices = ij_indices[..., 1]  # (B, D)
    # Sample from left and right children
    if self._two_inputs:
        # Left child
        left_ctx = sampling_ctx.with_routing(channel_index=left_indices, mask=sampling_ctx.mask)
        self.inputs[0]._sample(data=data, cache=cache, sampling_ctx=left_ctx)
        # Right child
        right_ctx = sampling_ctx.with_routing(channel_index=right_indices, mask=sampling_ctx.mask)
        self.inputs[1]._sample(data=data, cache=cache, sampling_ctx=right_ctx)
    else:
        # Single input with Split module - merge the per-half indices via the split
        full_indices = self.inputs.merge_split_tensors(left_indices, right_indices)
        # NOTE(review): the mask is expanded in a fixed interleaved layout
        # (m0, m0, m1, m1, ...) while the indices go through merge_split_tensors;
        # confirm both layouts agree for every split mode.
        full_mask = repeat(sampling_ctx.mask, "b f -> b (f two)", two=2)
        child_ctx = sampling_ctx.with_routing(channel_index=full_indices, mask=full_mask)
        self.inputs._sample(data=data, cache=cache, sampling_ctx=child_ctx)
    return data
def _expectation_maximization_step(
    self,
    data: Tensor,
    bias_correction: bool = True,
    *,
    cache: Cache,
) -> None:
    """Update the routing weights with one EM step, then recurse into the children.

    The expected count of every ``(i, j)`` input pair is accumulated in log
    space from the current log-weights, the gradient stored on this module's
    cached log-likelihoods, and the children's log-likelihoods; the counts
    are summed over the batch and renormalized over the pair grid.

    Args:
        data: Training data tensor.
        bias_correction: Not used by this layer itself; forwarded to the
            children.
        cache: Cache holding this module's log-likelihoods under
            ``"log_likelihood"``.

    Raises:
        MissingCacheError: If this module's log-likelihoods are not cached, or
            if the cached tensor has no ``.grad``.
    """
    with torch.no_grad():
        left_ll, right_ll = self._get_left_right_ll(data, cache)
        module_lls = cache["log_likelihood"].get(self)
        if module_lls is None:
            raise MissingCacheError("Module log-likelihoods not in cache. Call log_likelihood first.")

        # Without this check a missing gradient surfaces as an opaque
        # TypeError from `None + 1e-10` below.
        ll_grad = module_lls.grad
        if ll_grad is None:
            raise MissingCacheError(
                "Module log-likelihoods carry no gradient. Call retain_grad() on the cached "
                "log-likelihoods and run backward before the EM step."
            )

        # E-step: broadcast every term to (B, D, O, R, I, J)
        log_weights = rearrange(self.log_weights, "f co r i j -> 1 f co r i j")
        left_ll = rearrange(left_ll, "b f i r -> b f 1 r i 1")
        right_ll = rearrange(right_ll, "b f j r -> b f 1 r 1 j")
        # Small epsilon keeps the log finite where the gradient is exactly zero.
        log_grads = rearrange(torch.log(ll_grad + 1e-10), "b f co r -> b f co r 1 1")
        module_lls = rearrange(module_lls, "b f co r -> b f co r 1 1")

        joint_input_ll = left_ll + right_ll  # (B, D, 1, R, I, J)
        log_expectations = log_weights + log_grads + joint_input_ll - module_lls
        log_expectations = log_expectations.logsumexp(0)  # sum over batch

        # Renormalize over the flattened (i, j) pair axis
        flat_expectations = rearrange(log_expectations, "f co r i j -> f co r (i j)")
        flat_log_weights = torch.nn.functional.log_softmax(flat_expectations, dim=-1)
        new_log_weights = rearrange(
            flat_log_weights,
            "f co r (i j) -> f co r i j",
            i=self._left_channels,
            j=self._right_channels,
        )

        # M-step: the property setter stores the logits and clears the eval cache
        self.log_weights = new_log_weights

    # Recurse to children (single-input mode skips the split wrapper itself)
    if self._two_inputs:
        self.inputs[0]._expectation_maximization_step(data, bias_correction=bias_correction, cache=cache)
        self.inputs[1]._expectation_maximization_step(data, bias_correction=bias_correction, cache=cache)
    else:
        self.inputs.inputs._expectation_maximization_step(
            data, bias_correction=bias_correction, cache=cache
        )
[docs]
def marginalize(
    self,
    marg_rvs: list[int],
    prune: bool = True,
    cache: Cache | None = None,
) -> Optional["EinsumLayer" | Module]:
    """Remove the given random variables from this layer.

    Args:
        marg_rvs: Random variable indices to marginalize.
        prune: Forwarded to the children's ``marginalize``.
        cache: Forwarded to the children's ``marginalize``.

    Returns:
        ``None`` if every variable in scope is marginalized, ``self`` if none
        is, and otherwise either a new ``EinsumLayer`` over the marginalized
        children (with freshly initialized weights) or the remaining child
        module when a layer can no longer be formed.
    """
    query = self.scope.query
    overlap = set(query) & set(marg_rvs)

    if len(overlap) == len(query):
        # Every variable in scope is marginalized out.
        return None
    if len(overlap) == 0:
        # Nothing in scope is affected.
        return self

    if not self._two_inputs:
        # Single input: marginalize the module wrapped by the split.
        reduced = self.inputs.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
        if reduced is None:
            return None
        remaining_features = reduced.out_shape.features
        # Fewer than two or an odd number of features cannot be paired up.
        if remaining_features < 2 or remaining_features % 2 != 0:
            return reduced
        return EinsumLayer(
            inputs=reduced,
            out_channels=self.out_shape.channels,
            num_repetitions=self.out_shape.repetitions,
        )

    left_marg = self.inputs[0].marginalize(marg_rvs, prune=prune, cache=cache)
    right_marg = self.inputs[1].marginalize(marg_rvs, prune=prune, cache=cache)
    if left_marg is None:
        # Only the right child can remain (it is None too if both vanished).
        return right_marg
    if right_marg is None:
        return left_marg
    # Both children survive: rebuild the layer on top of them.
    return EinsumLayer(
        inputs=[left_marg, right_marg],
        out_channels=self.out_shape.channels,
        num_repetitions=self.out_shape.repetitions,
    )