from __future__ import annotations
import numpy as np
import torch
from torch import Tensor
from spflow.exceptions import (
InvalidParameterCombinationError,
InvalidWeightsError,
MissingCacheError,
ShapeError,
)
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.cat import Cat
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import (
proj_convex_to_real,
)
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context
class Sum(Module):
"""Sum module representing mixture operations in probabilistic circuits.
Implements mixture modeling by computing weighted combinations of child distributions.
Weights are normalized to sum to one, maintaining valid probability distributions.
Supports both single input (mixture over channels) and multiple inputs (mixture
over concatenated inputs).
Attributes:
inputs (Module): Input module(s) to the sum node.
sum_dim (int): Dimension over which to sum the inputs.
weights (Tensor): Normalized weights for mixture components.
logits (Parameter): Unnormalized log-weights for gradient optimization.
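Example:
A minimal sketch in plain torch (independent of any concrete child module) of the
mixture computation a sum node performs: child log-likelihoods are combined with
normalized weights via log-sum-exp.

>>> import torch
>>> child_lls = torch.tensor([-1.2, -0.7])  # log p_1(x), log p_2(x)
>>> weights = torch.tensor([0.3, 0.7])  # normalized: sums to one
>>> mixture_ll = torch.logsumexp(child_lls + weights.log(), dim=0)
>>> torch.isclose(mixture_ll, torch.log(weights @ child_lls.exp()))
tensor(True)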
"""
def __init__(
self,
inputs: Module | list[Module],
out_channels: int | None = None,
num_repetitions: int = 1,
weights: Tensor | list[float] | None = None,
) -> None:
"""Create a Sum module for mixture modeling.
Weights are automatically normalized to sum to one using softmax.
Multiple inputs are concatenated internally along the channel dimension (dim=2).
Args:
inputs (Module | list[Module]): Single module or list of modules to mix.
out_channels (int | None, optional): Number of output mixture components.
Required if weights is not provided.
num_repetitions (int, optional): Number of repetitions for structured
representations. Defaults to 1; inferred from weights if they carry a repetition dimension.
weights (Tensor | list[float] | None, optional): Initial mixture weights.
Must have a shape compatible with the inputs and out_channels.
Raises:
ValueError: If inputs is empty or out_channels < 1.
InvalidParameterCombinationError: If both out_channels and weights are specified,
or if num_repetitions contradicts the repetition count implied by weights.
ShapeError: If weights is not a 1D, 2D, 3D, or 4D tensor or has an invalid shape.
InvalidWeightsError: If weights contains non-positive values or does not sum to one.
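Example:
Sketch of how lower-rank weight tensors are broadcast into the internal
(features, in_channels, out_channels, repetitions) layout, mirroring the
reshaping performed by this constructor:

>>> import torch
>>> w1d = torch.tensor([0.3, 0.7])  # (in_channels,)
>>> w1d.view(1, -1, 1, 1).shape
torch.Size([1, 2, 1, 1])
>>> w2d = torch.full((2, 3), 0.5)  # (in_channels, out_channels)
>>> w2d.view(1, 2, 3, 1).shape
torch.Size([1, 2, 3, 1])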
"""
super().__init__()
# ========== 1. INPUT VALIDATION ==========
if not inputs:
raise ValueError("'Sum' requires at least one input to be specified.")
# ========== 2. WEIGHT PARAMETER PROCESSING ==========
if isinstance(weights, list):
weights = torch.as_tensor(weights, dtype=torch.get_default_dtype())
weights, out_channels_inferred, num_repetitions_inferred = self._process_weights_parameter(
inputs=inputs,
weights=weights,
out_channels=out_channels,
num_repetitions=num_repetitions,
)
# Use inferred values
out_channels = out_channels_inferred
if num_repetitions_inferred is not None:
num_repetitions = num_repetitions_inferred
if num_repetitions is None:
num_repetitions = 1
# ========== 3. CONFIGURATION VALIDATION ==========
if out_channels is None or out_channels < 1:
raise ValueError(
f"Number of nodes for 'Sum' must be greater than or equal to 1 but was {out_channels}."
)
# ========== 4. INPUT MODULE SETUP ==========
if isinstance(inputs, list):
if len(inputs) == 1:
self.inputs = inputs[0]
else:
self.inputs = Cat(inputs=inputs, dim=2)
else:
self.inputs = inputs
self.sum_dim = 1
self.scope = self.inputs.scope
# ========== 5. SHAPE COMPUTATION (early, so shapes can be reused below) ==========
self.in_shape = self.inputs.out_shape
self.out_shape = ModuleShape(
features=self.in_shape.features, channels=out_channels, repetitions=num_repetitions
)
# ========== 6. WEIGHT INITIALIZATION & PARAMETER REGISTRATION ==========
self.weights_shape = self._get_weights_shape()
weights = self._initialize_weights(weights)
# Register parameter for unnormalized log-probabilities
self.logits = torch.nn.Parameter()
# Set weights (converts to logits internally via property setter)
self.weights = weights
def _process_weights_parameter(
self,
inputs: Module | list[Module],
weights: Tensor | None,
out_channels: int | None,
num_repetitions: int | None,
) -> tuple[Tensor | None, int | None, int | None]:
if weights is None:
return weights, out_channels, num_repetitions
if out_channels is not None:
raise InvalidParameterCombinationError(
"Cannot specify both 'out_channels' and 'weights' for 'Sum' module."
)
weight_dim = weights.dim()
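# Broadcast lower-rank weights into the canonical layout
# (features, in_channels, out_channels, repetitions):
# 1D (IC,) -> (1, IC, 1, 1), 2D (IC, OC) -> (1, IC, OC, 1), 3D (F, IC, OC) -> (F, IC, OC, 1)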
if weight_dim == 1:
weights = weights.view(1, -1, 1, 1)
elif weight_dim == 2:
weights = weights.view(1, weights.shape[0], weights.shape[1], 1)
elif weight_dim == 3:
weights = weights.unsqueeze(-1)
elif weight_dim == 4:
pass
else:
raise ShapeError(f"Weights for 'Sum' must be a 1D, 2D, 3D, or 4D tensor but was {weight_dim}D.")
inferred_num_repetitions = weights.shape[-1]
if num_repetitions is not None and (
num_repetitions != 1 and num_repetitions != inferred_num_repetitions
):
raise InvalidParameterCombinationError(
f"Cannot specify 'num_repetitions' that does not match weights shape for 'Sum' module. "
f"Was {num_repetitions} but weights shape indicates {inferred_num_repetitions}."
)
num_repetitions = inferred_num_repetitions
out_channels = weights.shape[2]
return weights, out_channels, num_repetitions
def _get_weights_shape(self) -> tuple[int, int, int, int]:
return (
self.in_shape.features,
self.in_shape.channels,
self.out_shape.channels,
self.out_shape.repetitions,
)
def _initialize_weights(self, weights: Tensor | None) -> Tensor:
if weights is None:
weights = torch.rand(self.weights_shape) + 1e-08
weights /= torch.sum(weights, dim=self.sum_dim, keepdim=True)
return weights
@property
def feature_to_scope(self) -> np.ndarray:
return self.inputs.feature_to_scope
@property
def log_weights(self) -> Tensor:
"""Returns the log weights of all nodes as a tensor.
Returns:
Tensor: Log weights normalized to sum to one.
"""
# project auxiliary weights onto weights that sum up to one
return torch.nn.functional.log_softmax(self.logits, dim=self.sum_dim)
@property
def weights(self) -> Tensor:
"""Returns the weights of all nodes as a tensor.
Returns:
Tensor: Weights normalized to sum to one.
"""
# project auxiliary weights onto weights that sum up to one
return torch.nn.functional.softmax(self.logits, dim=self.sum_dim)
@weights.setter
def weights(
self,
values: Tensor,
) -> None:
"""Set weights of all nodes.
Args:
values: Tensor containing weights for each input and node.
Raises:
ShapeError: If weights have invalid shape.
InvalidWeightsError: If weights contain non-positive values or do not sum to one.
"""
if values.shape != self.weights_shape:
raise ShapeError(
f"Invalid shape for weights: Was {values.shape} but expected {self.weights_shape}."
)
if not torch.all(values > 0):
raise InvalidWeightsError("Weights for 'Sum' must be all positive.")
if not torch.allclose(torch.sum(values, dim=self.sum_dim), values.new_tensor(1.0)):
raise InvalidWeightsError("Weights for 'Sum' must sum up to one.")
self.logits.data = proj_convex_to_real(values)
@log_weights.setter
def log_weights(
self,
values: Tensor,
) -> None:
"""Set log weights of all nodes.
Args:
values: Tensor containing log weights for each input and node.
Raises:
ShapeError: If log weights have invalid shape.
"""
if values.shape != self.log_weights.shape:
raise ShapeError(f"Invalid shape for log weights: Was {values.shape} but expected {self.log_weights.shape}.")
self.logits.data = values
def extra_repr(self) -> str:
return f"{super().extra_repr()}, weights={self.weights_shape}"
@cached
def log_likelihood(
self,
data: Tensor,
cache: Cache | None = None,
) -> Tensor:
"""Compute log likelihood P(data | module).
Computes log likelihood using logsumexp for numerical stability.
Results are cached for parameter learning algorithms.
Args:
data: Input data of shape (batch_size, num_features).
NaN values indicate evidence for conditional computation.
cache: Cache for intermediate computations. Defaults to None.
Returns:
Tensor: Log-likelihoods of shape (batch_size, num_features, out_channels, num_repetitions).
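Example:
Shape sketch of the computation below, using random stand-ins for the child
log-likelihoods (B=batch, F=features, IC/OC=input/output channels, R=repetitions):

>>> import torch
>>> B, F, IC, OC, R = 4, 3, 2, 5, 1
>>> child_ll = torch.randn(B, F, IC, R)
>>> log_w = torch.rand(F, IC, OC, R).log_softmax(dim=1)
>>> torch.logsumexp(child_ll.unsqueeze(3) + log_w.unsqueeze(0), dim=2).shape
torch.Size([4, 3, 5, 1])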
"""
if cache is None:
cache = Cache()
# Get input log-likelihoods
ll = self.inputs.log_likelihood(
data,
cache=cache,
)
ll = ll.unsqueeze(3)  # shape: (B, F, IC, 1, R)
log_weights = self.log_weights.unsqueeze(0) # shape: (1, F, IC, OC, R)
# Weighted log-likelihoods
weighted_lls = ll + log_weights # shape: (B, F, IC, OC, R)
# Sum over input channels (sum_dim + 1 since here the batch dimension is the first dimension)
output = torch.logsumexp(weighted_lls, dim=self.sum_dim + 1)
batch_size = output.shape[0]
result = output.view(
batch_size, self.out_shape.features, self.out_shape.channels, self.out_shape.repetitions
)
return result
def sample(
self,
num_samples: int | None = None,
data: Tensor | None = None,
is_mpe: bool = False,
cache: Cache | None = None,
sampling_ctx: SamplingContext | None = None,
) -> Tensor:
"""Generate samples from sum module.
Args:
num_samples: Number of samples to generate.
data: Data tensor with NaN values to fill with samples.
is_mpe: Whether to perform maximum a posteriori estimation.
cache: Optional cache dictionary.
sampling_ctx: Optional sampling context.
Returns:
Tensor: Sampled values.
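If the cache holds input log-likelihoods for the given data (i.e. log_likelihood was
called with the same cache beforehand), channel indices are drawn from the posterior
over input channels (log-weights plus input log-likelihoods, renormalized); otherwise
they are drawn from the prior given by the mixture weights alone.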
"""
if cache is None:
cache = Cache()
# Handle num_samples case (create empty data tensor)
if data is None:
if num_samples is None:
num_samples = 1
data = torch.full((num_samples, len(self.scope.query)), float("nan")).to(self.device)
# Initialize sampling context if not provided
sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)
# Index into the correct weight channels given by parent module
if sampling_ctx.repetition_idx is not None:
logits = self.logits.unsqueeze(0).expand(
sampling_ctx.channel_index.shape[0], -1, -1, -1, -1
) # shape [b , n_features , in_c, out_c, r]
indices = sampling_ctx.repetition_idx # Shape (batch, features)
# Use gather to select the correct repetition
# Repeat indices to match the target dimension for gathering
in_channels_total = logits.shape[2]
indices = indices.view(-1, 1, 1, 1, 1).expand(
-1, logits.shape[1], in_channels_total, logits.shape[3], -1
)
# Gather the logits based on the repetition indices
logits = torch.gather(logits, dim=-1, index=indices).squeeze(-1)
else:
if self.out_shape.repetitions > 1:
raise ValueError(
"sampling_ctx.repetition_idx must be provided when sampling from a module with "
"num_repetitions > 1."
)
logits = self.logits[..., 0] # Select the 0th repetition
logits = logits.unsqueeze(0) # Make space for the batch
# Expand to batch size
logits = logits.expand(sampling_ctx.channel_index.shape[0], -1, -1, -1)
idxs = sampling_ctx.channel_index[..., None, None]
in_channels_total = logits.shape[2]
idxs = idxs.expand(-1, -1, in_channels_total, -1)
# Gather the logits based on the channel indices
logits = logits.gather(dim=3, index=idxs).squeeze(3)
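# logits now has shape (B, F, IC): unnormalized log-weights over input channels
# for each sample's selected output channel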
# Check if evidence is given (cached log-likelihoods)
if "log_likelihood" in cache and cache["log_likelihood"].get(self.inputs) is not None:
# Get the log likelihoods from the cache
input_lls = cache["log_likelihood"][self.inputs]
if sampling_ctx.repetition_idx is not None:
indices = sampling_ctx.repetition_idx.view(-1, 1, 1, 1).expand(
-1, input_lls.shape[1], input_lls.shape[2], -1
)
# Use gather to select the correct repetition
input_lls = torch.gather(input_lls, dim=-1, index=indices).squeeze(-1)
elif input_lls.dim() == 4 and input_lls.shape[-1] == 1:
# Without a repetition index, squeeze the singleton repetition dimension of input_lls
input_lls = input_lls.squeeze(-1)
# Combine the prior (mixture log-weights) with the evidence to obtain the posterior over input channels
logits = (logits + input_lls).log_softmax(dim=2)
# Select indices into the input channels (MPE: argmax, otherwise categorical sampling)
if is_mpe:
# Take the argmax of the logits to obtain the most probable index
new_channel_index = torch.argmax(logits, dim=-1)
else:
# Sample from categorical distribution defined by weights to obtain indices into input channels
new_channel_index = torch.distributions.Categorical(logits=logits).sample()
# Update sampling context with new channel indices
# If shape changes, expand the mask to match new channel_index shape
if new_channel_index.shape != sampling_ctx.mask.shape:
# Expand mask from (batch, 1) or (batch, old_features) to (batch, new_features)
new_mask = sampling_ctx.mask.expand_as(new_channel_index).contiguous()
if new_mask.shape != new_channel_index.shape:
# Fall back to creating a full True mask
new_mask = torch.ones(
new_channel_index.shape, dtype=torch.bool, device=new_channel_index.device
)
sampling_ctx.update(new_channel_index, new_mask)
else:
sampling_ctx.channel_index = new_channel_index
# Sample from input module
self.inputs.sample(
data=data,
is_mpe=is_mpe,
cache=cache,
sampling_ctx=sampling_ctx,
)
return data
def expectation_maximization(
self,
data: Tensor,
bias_correction: bool = True,
cache: Cache | None = None,
) -> None:
"""Perform expectation-maximization step.
Args:
data: Input data tensor.
bias_correction: Whether to apply bias correction.
cache: Optional cache containing log-likelihoods from a previous log_likelihood call.
Raises:
MissingCacheError: If required log-likelihoods are not found in cache.
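In sketch form, the log-space update computed below is (per feature f, input channel i,
output channel o and repetition r, normalized over input channels):

w_new[f, i, o, r] proportional to sum_b grad[b, f, o, r] * w_old[f, i, o, r] * p_in[b, f, i, r] / p_out[b, f, o, r]

where grad is the gradient accumulated on the cached module log-likelihoods, p_in are the
cached input likelihoods, and p_out are this module's cached output likelihoods.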
"""
if cache is None:
cache = Cache()
with torch.no_grad():
# ----- expectation step -----
# Get input LLs from cache
input_lls = cache["log_likelihood"].get(self.inputs)
if input_lls is None:
raise MissingCacheError(
"Input log-likelihoods not found in cache. Call log_likelihood first."
)
# Get module lls from cache
module_lls = cache["log_likelihood"].get(self)
if module_lls is None:
raise MissingCacheError(
"Module log-likelihoods not found in cache. Call log_likelihood first."
)
log_weights = self.log_weights.unsqueeze(0)
log_grads = torch.log(module_lls.grad).unsqueeze(2)
input_lls = input_lls.unsqueeze(3)
module_lls = module_lls.unsqueeze(2)
log_expectations = log_weights + log_grads + input_lls - module_lls
log_expectations = log_expectations.logsumexp(0) # Sum over batch dimension
log_expectations = log_expectations.log_softmax(self.sum_dim) # Normalize
# ----- maximization step -----
self.log_weights = log_expectations
# Recursively call EM on inputs
self.inputs.expectation_maximization(data, cache=cache, bias_correction=bias_correction)
def maximum_likelihood_estimation(
self,
data: Tensor,
weights: Tensor | None = None,
cache: Cache | None = None,
) -> None:
"""Update parameters via maximum likelihood estimation.
For Sum modules, this is equivalent to EM.
Args:
data: Input data tensor.
weights: Optional sample weights (currently unused).
cache: Optional cache dictionary.
"""
self.expectation_maximization(data, cache=cache)
def marginalize(
self,
marg_rvs: list[int],
prune: bool = True,
cache: Cache | None = None,
) -> Sum | None:
"""Marginalize out specified random variables.
Args:
marg_rvs: List of random variables to marginalize.
prune: Whether to prune the module.
cache: Optional cache dictionary.
Returns:
Marginalized Sum module or None.
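If the scope is fully contained in marg_rvs, the module is removed (None is returned).
If it is only partially marginalized, the inputs are marginalized recursively and the
weight rows of fully marginalized features are dropped per repetition before building
the new Sum.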
"""
if cache is None:
cache = Cache()
# compute module scope (same for all outputs)
module_scope = self.scope
marg_input = None
mutual_rvs = set(module_scope.query).intersection(set(marg_rvs))
module_weights = self.weights
# module scope is being fully marginalized over
if len(mutual_rvs) == len(module_scope.query):
# the whole scope of this branch is marginalized over, so the module is removed
return None
# node scope is being partially marginalized
elif mutual_rvs:
# marginalize input modules
marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
# if marginalized input is not None
if marg_input:
# Apply mask to weights per-repetition
masked_weights_list = []
for r in range(self.out_shape.repetitions):
feature_to_scope_r = self.inputs.feature_to_scope[:, r].copy()
# remove mutual_rvs from feature_to_scope list
for rv in mutual_rvs:
for idx, scope in enumerate(feature_to_scope_r):
if scope is not None:
if rv in scope.query:
feature_to_scope_r[idx] = scope.remove_from_query(rv)
# construct mask with empty scopes
mask = torch.tensor(
[not scope.empty() for scope in feature_to_scope_r], device=self.device
).bool()
# Apply mask to weights for this repetition: (out_features, in_channels, out_channels)
masked_weights_r = module_weights[:, :, :, r][mask]
masked_weights_list.append(masked_weights_r)
# Stack weights back along the repetition dimension
# Handle different repetition counts if needed
if all(w.shape[0] == masked_weights_list[0].shape[0] for w in masked_weights_list):
# All repetitions have same number of features, can stack directly
module_weights = torch.stack(masked_weights_list, dim=-1)
else:
# Feature counts differ across repetitions - this shouldn't happen in practice,
# but handle it gracefully by zero-padding the smaller ones to the maximum feature count
max_features = max(w.shape[0] for w in masked_weights_list)
padded_list = []
for w in masked_weights_list:
if w.shape[0] < max_features:
padding = torch.zeros(
max_features - w.shape[0],
w.shape[1],
w.shape[2],
device=w.device,
dtype=w.dtype,
)
w = torch.cat([w, padding], dim=0)
padded_list.append(w)
module_weights = torch.stack(padded_list, dim=-1)
else:
marg_input = self.inputs
if marg_input is None:
return None
else:
return Sum(inputs=marg_input, weights=module_weights)