from __future__ import annotations
from typing import Optional
import numpy as np
import torch
from torch import Tensor, nn
from spflow.exceptions import (
InvalidParameterCombinationError,
InvalidWeightsError,
MissingCacheError,
ScopeError,
ShapeError,
)
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import (
proj_convex_to_real,
)
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context
class ElementwiseSum(Module):
"""Elementwise sum operation for mixture modeling.
    Computes an element-wise weighted sum over the outputs of its input modules.
    Weights are normalized to sum to one across the input modules, and all
    likelihood computations are carried out in the log domain for numerical stability.
Attributes:
logits (Parameter): Unnormalized log-weights for gradient optimization.
unraveled_channel_indices (Tensor): Mapping for flattened channel indices.
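    Example:
        A minimal sketch of wiring up an elementwise sum (``leaf_a`` and ``leaf_b``
        stand in for any input modules with identical features, compatible channels
        and the same scope; ``data`` is a batch tensor over that scope)::

            es = ElementwiseSum(inputs=[leaf_a, leaf_b], out_channels=2)
            lls = es.log_likelihood(data)  # (batch, features, out_channels_total, repetitions)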
"""
def __init__(
self,
inputs: list[Module],
out_channels: int | None = None,
weights: Tensor | None = None,
num_repetitions: int | None = None,
) -> None:
"""Initialize elementwise sum module.
Args:
inputs: Input modules (same features, compatible channels).
            out_channels: Number of sum nodes per input channel. The module produces a total of
                ``out_channels * in_channels`` output channels, where ``in_channels`` is the
                (maximum) channel count of the input modules, since the weighted sum runs over
                the list of input modules rather than over channels.
weights: Initial weights (if None, randomly initialized).
num_repetitions: Number of repetitions.
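        Example:
            A sketch of the channel arithmetic: two input modules with 3 channels each
            and ``out_channels=4`` yield ``4 * 3 = 12`` output channels (``leaf_a`` and
            ``leaf_b`` stand in for compatible input modules)::

                es = ElementwiseSum(inputs=[leaf_a, leaf_b], out_channels=4)
                assert es.out_shape.channels == 12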
"""
super().__init__()
# ========== 1. INPUT VALIDATION ==========
if not inputs:
raise ValueError("'Sum' requires at least one input to be specified.")
# ========== 2. WEIGHTS PARAMETER PROCESSING ==========
if weights is not None:
# Validate mutual exclusivity
if out_channels is not None:
                raise InvalidParameterCombinationError(
                    "Cannot specify both 'out_channels' and 'weights' for 'ElementwiseSum' module."
                )
if num_repetitions is not None:
                raise InvalidParameterCombinationError(
                    "Cannot specify both 'num_repetitions' and 'weights' for 'ElementwiseSum' module."
                )
if weights.dim() != 5:
raise ShapeError(
f"Weights for 'ElementwiseSum' must be a 5D tensor but was {weights.dim()}D."
)
# Derive configuration from weights shape
out_channels = weights.shape[2]
inferred_num_repetitions = weights.shape[4]
if num_repetitions is not None and num_repetitions != inferred_num_repetitions:
raise InvalidParameterCombinationError(
f"Cannot specify 'num_repetitions' that does not match weights shape for 'Sum' module. "
f"Was {num_repetitions} but weights shape indicates {inferred_num_repetitions}."
)
num_repetitions = inferred_num_repetitions
else:
# Set defaults when weights not provided
if out_channels is None:
raise ValueError(f"Either 'out_channels' or 'weights' must be specified for 'Sum' module.")
if num_repetitions is None:
num_repetitions = 1
# ========== 3. CONFIGURATION VALIDATION ==========
if out_channels < 1:
raise ValueError(
f"Number of nodes for 'Sum' must be greater of equal to 1 but was {out_channels}."
)
# Validate all inputs have the same number of features
if not all([module.out_shape.features == inputs[0].out_shape.features for module in inputs]):
raise ShapeError("All inputs must have the same number of features.")
# Validate all inputs have compatible channels (same or 1 for broadcasting)
if not all(
[module.out_shape.channels in (1, max(m.out_shape.channels for m in inputs)) for module in inputs]
):
raise ShapeError(
"All inputs must have compatible channels: same number of channels or 1 channel (in which "
"case the operation is broadcast)."
)
# Validate all input modules have the same scope
if not Scope.all_equal([module.scope for module in inputs]):
raise ScopeError("All input modules must have the same scope.")
# Validate for each repetition that modules have the same features_to_scope mapping
for rep in range(num_repetitions):
feature_to_scope = inputs[0].feature_to_scope[..., rep]
for module in inputs[1:]:
if not np.array_equal(feature_to_scope, module.feature_to_scope[..., rep]):
raise ScopeError(
"All input modules must have the same feature to scope mapping for each repetition."
)
# ========== 4. INPUT MODULE SETUP ==========
self.inputs = nn.ModuleList(inputs)
self.sum_dim = 3
self.scope = inputs[0].scope
# ========== 5. SHAPE COMPUTATION (early, so shapes can be reused below) ==========
in_channels = max(module.out_shape.channels for module in self.inputs)
out_channels_total = out_channels * in_channels
self._num_sums = out_channels # Store for use in sampling
self.in_shape = ModuleShape(inputs[0].out_shape.features, in_channels, num_repetitions)
self.out_shape = ModuleShape(self.in_shape.features, out_channels_total, num_repetitions)
# ========== 6. WEIGHT INITIALIZATION & PARAMETER REGISTRATION ==========
self.weights_shape = (
self.in_shape.features,
self.in_shape.channels,
out_channels,
len(inputs),
self.out_shape.repetitions,
)
        # Register unraveled channel indices for mapping flattened indices to (channel, sum) pairs.
        # E.g. for 3 in_channels and 2 out_channels: [0,1,2,3,4,5] -> [(0,0), (0,1), (1,0), (1,1), (2,0), (2,1)]
unraveled_channel_indices = torch.tensor(
[(i, j) for i in range(self.in_shape.channels) for j in range(self._num_sums)],
dtype=torch.long,
)
self.register_buffer(name="unraveled_channel_indices", tensor=unraveled_channel_indices)
if weights is None:
# Initialize weights randomly with small epsilon to avoid zeros
weights = torch.rand(self.weights_shape) + 1e-08
# Normalize to sum to one along sum_dim
weights /= torch.sum(weights, dim=self.sum_dim, keepdim=True)
# Register parameter for unnormalized log-probabilities
self.logits = torch.nn.Parameter(torch.zeros(self.weights_shape))
# Set weights (converts to logits internally via property setter)
self.weights = weights
@property
def feature_to_scope(self) -> np.ndarray:
return self.inputs[0].feature_to_scope
@property
def log_weights(self) -> Tensor:
# project auxiliary weights onto weights that sum up to one
return torch.nn.functional.log_softmax(self.logits, dim=self.sum_dim)
@property
def weights(self) -> Tensor:
# project auxiliary weights onto weights that sum up to one
return torch.nn.functional.softmax(self.logits, dim=self.sum_dim)
@weights.setter
def weights(
self,
values: Tensor,
) -> None:
"""Set weights of all nodes.
Args:
values: Weight values to set.
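        Example:
            A sketch of assigning valid weights: strictly positive and normalized along
            the input-module dimension (``dim=3`` of ``weights_shape``)::

                w = torch.rand(module.weights_shape) + 1e-08
                module.weights = w / w.sum(dim=3, keepdim=True)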
"""
        if values.shape != self.weights_shape:
            raise ShapeError(f"Invalid shape for weights: {values.shape}, expected {self.weights_shape}.")
        if not torch.all(values > 0):
            raise InvalidWeightsError("Weights for 'ElementwiseSum' must be all positive.")
        if not torch.allclose(torch.sum(values, dim=self.sum_dim), values.new_tensor(1.0)):
            raise InvalidWeightsError("Weights for 'ElementwiseSum' must sum up to one.")
self.logits.data = proj_convex_to_real(values)
@log_weights.setter
def log_weights(
self,
values: Tensor,
) -> None:
"""Set log weights of all nodes.
Args:
values: Log weight values to set.
"""
        if values.shape != self.log_weights.shape:
            raise ShapeError(f"Invalid shape for log weights: {values.shape}, expected {self.log_weights.shape}.")
self.logits.data = values
def extra_repr(self) -> str:
return f"{super().extra_repr()}, weights={self.weights_shape}"
def marginalize(
self,
marg_rvs: list[int],
prune: bool = True,
cache: Cache | None = None,
) -> Optional["ElementwiseSum"]:
"""Marginalize out specified random variables.
Args:
marg_rvs: Random variables to marginalize out.
prune: Whether to prune the resulting module.
cache: Cache for memoization.
Returns:
Optional[ElementwiseSum]: Marginalized module or None if fully marginalized.
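        Example:
            A sketch of marginalizing out the random variables 0 and 1; the result is
            ``None`` if they cover the module's whole scope::

                marg = module.marginalize([0, 1])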
"""
# initialize cache
if cache is None:
cache = Cache()
# compute module scope (same for all outputs)
module_scope = self.scope
marg_input = None
mutual_rvs = set(module_scope.query).intersection(set(marg_rvs))
module_weights = self.weights
# module scope is being fully marginalized over
if len(mutual_rvs) == len(module_scope.query):
# passing this loop means marginalizing over the whole scope of this branch
pass
# node scope is being partially marginalized
elif mutual_rvs:
# marginalize input modules
marg_input = [inp.marginalize(marg_rvs, prune=prune, cache=cache) for inp in self.inputs]
if all(mi is None for mi in marg_input):
marg_input = None
# if marginalized input is not None
if marg_input:
indices = [self.scope.query.index(el) for el in list(mutual_rvs)]
mask = torch.ones(len(module_scope.query), device=module_weights.device, dtype=torch.bool)
mask[indices] = False
module_weights = module_weights[mask]
else:
marg_input = self.inputs
if marg_input is None:
return None
else:
return ElementwiseSum(inputs=[inp for inp in marg_input], weights=module_weights)
def sample(
self,
num_samples: int | None = None,
data: Tensor | None = None,
is_mpe: bool = False,
cache: Cache | None = None,
sampling_ctx: Optional[SamplingContext] = None,
) -> Tensor:
"""Generate samples by choosing mixture components.
Args:
num_samples: Number of samples to generate.
data: Existing data tensor to fill with samples.
is_mpe: Whether to perform most probable explanation.
cache: Cache for memoization.
sampling_ctx: Sampling context for conditional sampling.
Returns:
Tensor: Generated samples.
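        Example:
            A sketch for a module with ``num_repetitions == 1``, assuming the default
            sampling context created by ``init_default_sampling_context`` provides the
            required channel indices::

                samples = module.sample(num_samples=100)
                mpe_samples = module.sample(num_samples=100, is_mpe=True)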
"""
# Prepare data tensor
data = self._prepare_sample_data(num_samples, data)
# initialize contexts
if cache is None:
cache = Cache()
sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0])
# Index into the correct weight channels given by parent module
# (stay in logits space since Categorical distribution accepts logits directly)
if sampling_ctx.repetition_idx is not None:
logits = self.logits.unsqueeze(0).expand(sampling_ctx.channel_index.shape[0], -1, -1, -1, -1, -1)
            indices = sampling_ctx.repetition_idx  # shape: (batch_size, 1, 1)
# Use gather to select the correct repetition
# Repeat indices to match the target dimension for gathering
indices = indices.view(-1, 1, 1, 1, 1, 1).expand(
-1, logits.shape[1], logits.shape[2], logits.shape[3], logits.shape[4], -1
)
logits = torch.gather(logits, dim=-1, index=indices).squeeze(-1)
else:
if self.out_shape.repetitions > 1:
raise ValueError(
"sampling_ctx.repetition_idx must be provided when sampling from a module with "
"num_repetitions > 1."
)
logits = self.logits[..., 0] # Select the 0th repetition
logits = logits.unsqueeze(0) # Make space for the batch
# Expand to batch size
logits = logits.expand(sampling_ctx.channel_index.shape[0], -1, -1, -1, -1)
cids_mapped = self.unraveled_channel_indices[sampling_ctx.channel_index]
        # Split the (input_channel_idx, sum_idx) pairs: the first element selects the input
        # channel (shared by all input modules), the second selects the sum node.
cids_in_channels_per_input = cids_mapped[..., 0]
cids_num_sums = cids_mapped[..., 1]
# Index weights with cids_num_sums (selects the correct output channel)
cids_num_sums = cids_num_sums[..., None, None, None].expand(
-1, -1, logits.shape[-3], -1, logits.shape[-1]
)
logits = logits.gather(dim=3, index=cids_num_sums).squeeze(3)
        # Index logits with cids_in_channels_per_input to get the correct logits for each input
logits = logits.gather(
dim=2, index=cids_in_channels_per_input[..., None, None].expand(-1, -1, -1, logits.shape[-1])
).squeeze(2)
if (
cache is not None
and "log_likelihood" in cache
and all(cache["log_likelihood"][inp] is not None for inp in self.inputs)
):
input_lls = [cache["log_likelihood"][inp] for inp in self.inputs]
            input_lls = torch.stack(input_lls, dim=self.sum_dim)
if sampling_ctx.repetition_idx is not None:
indices = sampling_ctx.repetition_idx.view(-1, 1, 1, 1, 1).expand(
-1, input_lls.shape[1], input_lls.shape[2], input_lls.shape[3], -1
)
input_lls = torch.gather(input_lls, dim=-1, index=indices).squeeze(-1)
is_conditional = True
else:
is_conditional = False
if is_conditional:
cids_in_channels_input_lls = (
cids_in_channels_per_input.unsqueeze(2).unsqueeze(3).expand(-1, -1, -1, input_lls.shape[3])
)
input_lls = input_lls.gather(dim=2, index=cids_in_channels_input_lls).squeeze(2)
            # Compute the log posterior by reweighting the prior logits with the input log-likelihoods
log_prior = logits
log_posterior = log_prior + input_lls
log_posterior = log_posterior.log_softmax(dim=2)
logits = log_posterior
# Sample/MPE from categorical distribution defined by weights to obtain indices into the Stack dimension
if is_mpe:
cids_stack = torch.argmax(logits, dim=-1)
else:
cids_stack = torch.distributions.Categorical(logits=logits).sample()
# Sample from input module
sampling_ctx.channel_index = cids_in_channels_per_input
for i, inp in enumerate(self.inputs):
            # Restrict the sampling mask to the samples routed to this input module
mask = sampling_ctx.mask & (cids_stack == i)
sampling_ctx_cpy = sampling_ctx.copy()
sampling_ctx_cpy.mask = mask
# Sample from input module
inp.sample(
data=data,
is_mpe=is_mpe,
cache=cache,
sampling_ctx=sampling_ctx_cpy,
)
return data
@cached
def log_likelihood(
self,
data: Tensor,
cache: Cache | None = None,
) -> Tensor:
"""Compute log likelihood via weighted log-sum-exp.
Args:
data: Input data tensor.
cache: Cache for memoization.
Returns:
Tensor: Computed log likelihood values.
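        Example:
            A sketch of the core computation with plain tensors, where ``stacked`` holds
            the input log-likelihoods stacked over the input modules and ``log_w`` is
            normalized over that same dimension (shapes annotated below)::

                weighted = stacked.unsqueeze(3) + log_w.unsqueeze(0)  # (B, F, IC, OC, num_inputs, R)
                out = torch.logsumexp(weighted, dim=4)                # (B, F, IC, OC, R)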
"""
# Get input log-likelihoods
lls = []
for inp in self.inputs:
ll = inp.log_likelihood(
data,
cache=cache,
)
# Prepare for broadcasting
if inp.out_shape.channels == 1 and self.in_shape.channels > 1:
ll = ll.expand(
data.shape[0], self.out_shape.features, self.in_shape.channels, self.out_shape.repetitions
)
lls.append(ll)
# Stack input log-likelihoods
        stacked_lls = torch.stack(lls, dim=self.sum_dim)  # shape: (B, F, IC, num_inputs, R)
        ll = stacked_lls.unsqueeze(3)  # shape: (B, F, IC, 1, num_inputs, R)
        log_weights = self.log_weights.unsqueeze(0)  # shape: (1, F, IC, OC, num_inputs, R)
        # Weighted log-likelihoods
        weighted_lls = ll + log_weights  # shape: (B, F, IC, OC, num_inputs, R)
        # Log-sum-exp over the stacked input modules (sum_dim + 1 because of the leading batch dimension)
        output = torch.logsumexp(weighted_lls, dim=self.sum_dim + 1)  # shape: (B, F, IC, OC, R)
output = output.view(
data.shape[0], self.out_shape.features, self.out_shape.channels, self.out_shape.repetitions
)
return output
def expectation_maximization(
self,
data: Tensor,
cache: Cache | None = None,
) -> None:
"""Perform EM step to update mixture weights.
Args:
data: Training data tensor.
cache: Cache for memoization.
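        Example:
            A sketch of one EM update, assuming a prior forward pass populated the cache
            and that the cached per-module log-likelihoods retain their gradients (e.g.
            via ``retain_grad()``) before ``backward()`` is called; ``root`` stands for
            the root module of the circuit::

                cache = Cache()
                ll = root.log_likelihood(data, cache=cache)
                ll.sum().backward()
                root.expectation_maximization(data, cache=cache)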
"""
# initialize cache
if cache is None:
cache = Cache()
with torch.no_grad():
# ----- expectation step -----
# Get input LLs
input_lls = []
for inp in self.inputs:
inp_ll = cache.get("log_likelihood", inp)
if inp_ll is None:
raise MissingCacheError("Input log-likelihoods not found in cache.")
input_lls.append(inp_ll)
input_lls = torch.stack(input_lls, dim=3)
# Get module lls
module_lls = cache.get("log_likelihood", self)
if module_lls is None:
raise MissingCacheError("Module log-likelihood not found in cache.")
log_weights = self.log_weights.unsqueeze(0)
input_lls = input_lls.unsqueeze(3)
            # Reshape the module log-likelihoods (and their gradients) so they broadcast against the weights
s = (
module_lls.shape[0],
self.out_shape.features,
self.in_shape.channels,
self._num_sums,
1,
)
if self.out_shape.repetitions is not None:
s = s + (self.out_shape.repetitions,)
log_grads = torch.log(module_lls.grad).view(s)
module_lls = module_lls.view(s)
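            # Log-domain E-step: combine the log-weights, the gradient of the backpropagated
            # log-likelihood w.r.t. this module's outputs, and the input log-likelihoods,
            # normalized by this module's own log-likelihoods; the result is summed over the
            # batch and renormalized over the input modules below.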
log_expectations = log_weights + log_grads + input_lls - module_lls
log_expectations = log_expectations.logsumexp(0) # Sum over batch dimension
log_expectations = log_expectations.log_softmax(self.sum_dim) # Normalize
# ----- maximization step -----
self.log_weights = log_expectations
for inp in self.inputs:
inp.expectation_maximization(data, cache=cache)
def maximum_likelihood_estimation(
self,
data: Tensor,
weights: Optional[Tensor] = None,
cache: Cache | None = None,
) -> None:
"""MLE step (equivalent to EM for sum nodes).
Args:
data: Training data tensor.
weights: Optional weights for data points.
cache: Cache for memoization.
"""
self.expectation_maximization(data, cache=cache)