"""Convolutional sum layer for probabilistic circuits.
Provides SumConv, which applies learned weighted sums over input channels
within spatial patches, enabling mixture modeling with spatial structure.
"""
from __future__ import annotations
import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F
from spflow.exceptions import InvalidWeightsError, MissingCacheError, ShapeError
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import proj_convex_to_real
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context
from spflow.modules.conv.utils import expand_sampling_context, upsample_sampling_context
class SumConv(Module):
"""Convolutional sum layer for probabilistic circuits.
Applies weighted sum over input channels within spatial patches. Weights are
learned and normalized to sum to one per patch position, maintaining valid
probability distributions. Useful for modeling spatial structure in image data.
    The layer expects input with spatial structure. Weights are shared across all
    patches: each position within the K x K kernel has its own weight set, which is
    reused at that position in every patch.
Attributes:
inputs (Module): Input module providing log-likelihoods.
kernel_size (int): Size of the spatial kernel (kernel_size x kernel_size).
in_channels (int): Number of input channels.
out_channels (int): Number of output channels (mixture components).
logits (Parameter): Unnormalized log-weights for gradient optimization.
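
    Example:
        A minimal pure-torch sketch of the per-position channel mixture this layer
        computes; shapes and names here are illustrative only, not spflow API::

            import torch
            from torch.nn import functional as F

            ll = torch.randn(8, 16, 3)            # (batch, features, in_channels) log-likelihoods
            logits = torch.randn(4, 3)            # (out_channels, in_channels) weights for one kernel position
            log_w = F.log_softmax(logits, dim=1)  # normalize over input channels
            out = torch.logsumexp(ll[:, :, None, :] + log_w, dim=-1)  # (batch, features, out_channels)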
"""
def __init__(
self,
inputs: Module,
out_channels: int,
kernel_size: int,
num_repetitions: int = 1,
) -> None:
"""Create a SumConv module for spatial mixture modeling.
Args:
inputs: Input module providing log-likelihoods with spatial structure.
out_channels: Number of output mixture components.
kernel_size: Size of the spatial kernel (kernel_size x kernel_size).
num_repetitions: Number of independent repetitions.
Raises:
ValueError: If out_channels < 1 or kernel_size < 1.
"""
super().__init__()
if out_channels < 1:
raise ValueError(f"out_channels must be >= 1, got {out_channels}")
if kernel_size < 1:
raise ValueError(f"kernel_size must be >= 1, got {kernel_size}")
self.inputs = inputs
self.kernel_size = kernel_size
self.sum_dim = 1 # Sum over input channels
# Infer input shape
input_shape = self.inputs.out_shape
self.in_channels = input_shape.channels
# Scope is inherited from input (per-pixel scopes preserved)
self.scope = self.inputs.scope
# Shape computation
self.in_shape = input_shape
self.out_shape = ModuleShape(
features=input_shape.features, # Spatial dimensions unchanged
channels=out_channels,
repetitions=num_repetitions,
)
# Weight shape: (out_channels, in_channels, kernel_size, kernel_size, repetitions)
self.weights_shape = (
out_channels,
self.in_channels,
kernel_size,
kernel_size,
num_repetitions,
)
        # Initialize weights with uniform random values, normalized over input channels
weights = torch.rand(self.weights_shape) + 1e-08
weights = weights / weights.sum(dim=self.sum_dim, keepdim=True)
# Register parameter for unnormalized log-probabilities
self.logits = torch.nn.Parameter(proj_convex_to_real(weights))
@property
def feature_to_scope(self) -> np.ndarray:
"""Per-pixel scopes are preserved from input."""
return self.inputs.feature_to_scope
@property
def log_weights(self) -> Tensor:
"""Returns the log weights normalized to sum to one over input channels.
Returns:
Tensor: Log weights of shape (out_c, in_c, k, k, reps).
"""
return F.log_softmax(self.logits, dim=self.sum_dim)
@property
def weights(self) -> Tensor:
"""Returns the weights normalized to sum to one over input channels.
Returns:
Tensor: Weights of shape (out_c, in_c, k, k, reps).
"""
return F.softmax(self.logits, dim=self.sum_dim)
@weights.setter
def weights(self, values: Tensor) -> None:
"""Set weights of all nodes.
Args:
values: Tensor containing weights.
Raises:
            ShapeError: If the weights do not match ``self.weights_shape``.
            InvalidWeightsError: If the weights are not strictly positive or do not
                sum to one over the input channel dimension.
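
        Example:
            A sketch of assigning valid weights to an existing layer ``sum_conv``
            with ``out_channels=4, in_channels=3, kernel_size=2, num_repetitions=1``
            (the instance and values are placeholders)::

                import torch

                w = torch.rand(4, 3, 2, 2, 1) + 1e-8
                w = w / w.sum(dim=1, keepdim=True)  # normalize over input channels
                sum_conv.weights = w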
"""
if values.shape != self.weights_shape:
raise ShapeError(
f"Invalid shape for weights: Was {values.shape} but expected {self.weights_shape}."
)
if not torch.all(values > 0):
raise InvalidWeightsError("Weights must be all positive.")
if not torch.allclose(values.sum(dim=self.sum_dim), values.new_tensor(1.0)):
raise InvalidWeightsError("Weights must sum to one over input channels.")
self.logits.data = proj_convex_to_real(values)
def extra_repr(self) -> str:
return (
f"in_channels={self.in_channels}, out_channels={self.out_shape.channels}, "
f"kernel_size={self.kernel_size}"
)
@cached
def log_likelihood(
self,
data: Tensor,
cache: Cache | None = None,
) -> Tensor:
"""Compute log likelihood using convolutional weighted sum.
Applies weighted sum over input channels within spatial patches.
Each kernel position gets its own set of mixture weights.
Uses logsumexp for numerical stability.
Args:
data: Input data of shape (batch_size, num_features).
cache: Cache for intermediate computations.
Returns:
Tensor: Log-likelihood of shape (batch, features, out_channels, reps).
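
        Example:
            A standalone sketch of the patch-wise weighted sum for ``H = W = 4`` and
            ``kernel_size = 2`` with a single repetition; names and shapes are
            illustrative only::

                import torch

                B, C_in, C_out, K = 2, 3, 4, 2
                ll = torch.randn(B, C_in, 4, 4)  # input log-likelihoods in spatial form
                log_w = torch.log_softmax(torch.randn(C_out, C_in, K, K), dim=1)
                patches = ll.view(B, C_in, 2, K, 2, K).permute(0, 1, 2, 4, 3, 5)
                out = torch.logsumexp(
                    patches.unsqueeze(1) + log_w[None, :, :, None, None, :, :], dim=2
                )  # (B, C_out, 2, 2, K, K)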
"""
if cache is None:
cache = Cache()
# Get input log-likelihoods: (batch, features, in_channels, reps)
ll = self.inputs.log_likelihood(data, cache=cache)
batch_size = ll.shape[0]
num_features = ll.shape[1]
in_channels = ll.shape[2]
in_reps = ll.shape[3]
# Handle repetition matching
out_reps = self.out_shape.repetitions
if in_reps == 1 and out_reps > 1:
            # Broadcast the singleton input repetition dimension to out_reps
            ll = ll.expand(-1, -1, -1, out_reps)
elif in_reps != out_reps and in_reps != 1:
raise ValueError(f"Input repetitions {in_reps} incompatible with output {out_reps}")
# Infer spatial dimensions from num_features
# Assume square spatial dimensions
H = W = int(num_features**0.5)
if H * W != num_features:
raise ValueError(
f"SumConv requires square spatial dimensions. Got {num_features} features "
f"which is not a perfect square."
)
K = self.kernel_size
# Special case: spatial dims smaller than kernel size
# Use only the first kernel weight position [0, 0]
if H < K or W < K:
# Get log weights for position [0, 0]: (out_c, in_c, reps)
log_weights = self.log_weights[:, :, 0, 0, :] # (out_c, in_c, reps)
# Reshape for broadcasting: (1, 1, out_c, in_c, reps)
log_weights = log_weights.view(1, 1, self.out_shape.channels, in_channels, out_reps)
# Reshape ll for broadcasting: (batch, features, 1, in_c, reps)
ll = ll.unsqueeze(2)
# Weighted sum over input channels: logsumexp over dim 3 (in_channels)
weighted_lls = ll + log_weights
result = torch.logsumexp(weighted_lls, dim=3) # (batch, features, out_c, reps)
return result
if H % K != 0 or W % K != 0:
raise ValueError(f"Spatial dims ({H}, {W}) must be divisible by kernel_size {K}")
# Get log weights: (out_c, in_c, k, k, reps)
log_weights = self.log_weights
# Reshape ll from (batch, features, in_c, reps) to spatial form
# (batch, in_c, H, W, reps)
ll = ll.permute(0, 2, 1, 3) # (batch, in_c, features, reps)
ll = ll.view(batch_size, in_channels, H, W, out_reps)
# Patch the input into KxK blocks
# (batch, in_c, H//K, K, W//K, K, reps)
ll = ll.view(batch_size, in_channels, H // K, K, W // K, K, out_reps)
# Reorder to (batch, in_c, H//K, W//K, K, K, reps)
ll = ll.permute(0, 1, 2, 4, 3, 5, 6)
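        # e.g. for H = W = 6, K = 2: (B, C, 6, 6, R) -> view -> (B, C, 3, 2, 3, 2, R)
        # -> permute -> (B, C, 3, 3, 2, 2, R), i.e. a 3x3 grid of 2x2 patches per channel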
# Make space for out_channels: (batch, 1, in_c, H//K, W//K, K, K, reps)
ll = ll.unsqueeze(1)
# Make space in log_weights for spatial dims: (1, out_c, in_c, 1, 1, K, K, reps)
log_weights = log_weights.unsqueeze(0).unsqueeze(3).unsqueeze(4)
# Weighted sum over input channels: logsumexp over dim 2 (in_channels)
weighted_lls = ll + log_weights
result = torch.logsumexp(weighted_lls, dim=2) # (batch, out_c, H//K, W//K, K, K, reps)
# Invert the patch transformation
# (batch, out_c, H//K, W//K, K, K, reps) -> (batch, out_c, H//K, K, W//K, K, reps)
result = result.permute(0, 1, 2, 4, 3, 5, 6)
# Reshape back to (batch, out_c, H, W, reps)
result = result.contiguous().view(batch_size, self.out_shape.channels, H, W, out_reps)
# Convert back to (batch, features, out_c, reps)
result = result.view(batch_size, self.out_shape.channels, num_features, out_reps)
result = result.permute(0, 2, 1, 3) # (batch, features, out_c, reps)
return result
def sample(
self,
num_samples: int | None = None,
data: Tensor | None = None,
is_mpe: bool = False,
cache: Cache | None = None,
sampling_ctx: SamplingContext | None = None,
) -> Tensor:
"""Generate samples from sum conv module.
Each spatial position samples from its per-position kernel weights.
Args:
num_samples: Number of samples to generate.
data: Data tensor with NaN values to fill with samples.
is_mpe: Whether to perform maximum a posteriori estimation.
cache: Optional cache dictionary.
sampling_ctx: Optional sampling context.
Returns:
Tensor: Sampled values.
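
        Example:
            A standalone sketch of the per-pixel component choice: combine the
            selected log-weights with cached input log-likelihoods and sample the
            input channel from the resulting posterior (names and shapes are
            illustrative only)::

                import torch
                from torch.nn import functional as F

                selected_logits = torch.randn(2, 4, 4, 3)  # (batch, H, W, in_channels)
                input_lls = torch.randn(2, 4, 4, 3)        # cached per-channel log-likelihoods
                log_post = F.log_softmax(selected_logits + input_lls, dim=-1)
                idx = torch.distributions.Categorical(logits=log_post.view(-1, 3)).sample()
                idx = idx.view(2, 4, 4)                    # sampled input channel per pixel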
"""
if cache is None:
cache = Cache()
# Handle num_samples case
if data is None:
if num_samples is None:
num_samples = 1
data = torch.full((num_samples, len(self.scope.query)), float("nan")).to(self.device)
batch_size = data.shape[0]
# Initialize sampling context
sampling_ctx = init_default_sampling_context(sampling_ctx, batch_size, data.device)
num_features = self.in_shape.features
# Infer spatial dimensions
H = W = int(num_features**0.5)
K = self.kernel_size
if H * W != num_features:
raise ValueError(
f"SumConv requires square spatial dimensions. Got {num_features} features "
f"which is not a perfect square."
)
if H % K != 0 or W % K != 0:
raise ValueError(f"Spatial dims ({H}, {W}) must be divisible by kernel_size {K}")
# Expand channel_index and mask to match input features if needed
current_features = sampling_ctx.channel_index.shape[1]
if current_features != num_features:
if current_features == 1:
expand_sampling_context(sampling_ctx, num_features)
else:
# Upsample from parent spatial dims to input spatial dims
upsample_sampling_context(
sampling_ctx,
current_height=H // K,
current_width=W // K,
scale_h=K,
scale_w=K,
)
channel_idx = sampling_ctx.channel_index # (batch, H*W)
# Get logits: (out_c, in_c, k, k, reps)
logits = self.logits
# Select repetition
if sampling_ctx.repetition_idx is not None:
# logits: (out_c, in_c, k, k, reps) -> select reps
rep_idx = sampling_ctx.repetition_idx.view(-1, 1, 1, 1, 1)
rep_idx = rep_idx.expand(batch_size, logits.shape[0], logits.shape[1], K, K)
logits = logits.unsqueeze(0).expand(batch_size, -1, -1, -1, -1, -1)
logits = torch.gather(logits, dim=-1, index=rep_idx.unsqueeze(-1)).squeeze(-1)
# logits: (batch, out_c, in_c, k, k)
else:
logits = logits[..., 0] # (out_c, in_c, k, k)
logits = logits.unsqueeze(0).expand(batch_size, -1, -1, -1, -1)
# logits: (batch, out_c, in_c, k, k)
# Check for cached likelihoods (conditional sampling)
input_lls = None
if (
cache is not None
and "log_likelihood" in cache
and cache["log_likelihood"].get(self.inputs) is not None
):
input_lls = cache["log_likelihood"][self.inputs] # (batch, features, in_c, reps)
# Select repetition
if sampling_ctx.repetition_idx is not None:
rep_idx = sampling_ctx.repetition_idx.view(-1, 1, 1, 1)
rep_idx = rep_idx.expand(-1, input_lls.shape[1], input_lls.shape[2], 1)
input_lls = torch.gather(input_lls, dim=-1, index=rep_idx).squeeze(-1)
else:
input_lls = input_lls[..., 0]
# input_lls: (batch, H*W, in_c)
# Reshape to spatial: (batch, H, W, in_c)
input_lls = input_lls.view(batch_size, H, W, self.in_channels)
# Reshape channel_idx to spatial: (batch, H, W)
channel_idx = channel_idx.view(batch_size, H, W)
# Sample per-position: each pixel position needs its own sample
# Create position indices for kernel
# Position within kernel: pixel (i, j) has kernel pos (i % K, j % K)
row_pos = torch.arange(H, device=data.device).view(1, H, 1).expand(batch_size, H, W)
col_pos = torch.arange(W, device=data.device).view(1, 1, W).expand(batch_size, H, W)
k_row = row_pos % K # (batch, H, W)
k_col = col_pos % K # (batch, H, W)
# logits: (batch, out_c, in_c, k, k)
        # Select logits for each position: first gather by kernel position,
        # then by parent channel index.
logits_per_pos = logits.permute(0, 3, 4, 1, 2) # (batch, k, k, out_c, in_c)
# Index by kernel position
# k_row, k_col: (batch, H, W)
# Need to gather from (batch, k, k, out_c, in_c)
# Flatten kernel dims for gathering
logits_per_pos = logits_per_pos.view(batch_size, K * K, self.out_shape.channels, self.in_channels)
# Compute flat kernel index
k_flat = k_row * K + k_col # (batch, H, W)
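        # e.g. with K = 2, kernel position (row=1, col=0) maps to flat index 1 * 2 + 0 = 2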
logits_per_pos_exp = logits_per_pos.unsqueeze(2).unsqueeze(3).expand(-1, -1, H, W, -1, -1)
# (batch, K*K, H, W, out_c, in_c) -> swap dims for gather
logits_per_pos_exp = logits_per_pos_exp.permute(0, 2, 3, 1, 4, 5) # (batch, H, W, K*K, out_c, in_c)
# Gather by k_flat
k_flat_exp2 = k_flat.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) # (batch, H, W, 1, 1, 1)
k_flat_exp2 = k_flat_exp2.expand(-1, -1, -1, -1, self.out_shape.channels, self.in_channels)
selected_logits = torch.gather(logits_per_pos_exp, dim=3, index=k_flat_exp2).squeeze(3)
# selected_logits: (batch, H, W, out_c, in_c)
# Now select by parent channel: channel_idx (batch, H, W)
parent_ch = channel_idx.unsqueeze(-1).unsqueeze(-1) # (batch, H, W, 1, 1)
parent_ch = parent_ch.expand(-1, -1, -1, -1, self.in_channels) # (batch, H, W, 1, in_c)
selected_logits = torch.gather(selected_logits, dim=3, index=parent_ch).squeeze(3)
# selected_logits: (batch, H, W, in_c)
# Compute posterior if we have cached likelihoods
if input_lls is not None:
# input_lls: (batch, H, W, in_c)
log_posterior = selected_logits + input_lls
log_posterior = F.log_softmax(log_posterior, dim=-1)
else:
log_posterior = F.log_softmax(selected_logits, dim=-1)
# Sample for each position
log_posterior_flat = log_posterior.view(-1, self.in_channels)
if is_mpe:
sampled_channels_flat = torch.argmax(log_posterior_flat, dim=-1)
else:
sampled_channels_flat = torch.distributions.Categorical(logits=log_posterior_flat).sample()
sampled_channels = sampled_channels_flat.view(batch_size, H, W)
sampled_channels = sampled_channels.view(batch_size, num_features)
# Update sampling context
sampling_ctx.channel_index = sampled_channels
# Sample from input
self.inputs.sample(
data=data,
is_mpe=is_mpe,
cache=cache,
sampling_ctx=sampling_ctx,
)
return data
def expectation_maximization(
self,
data: Tensor,
bias_correction: bool = True,
cache: Cache | None = None,
) -> None:
"""Perform expectation-maximization step to update weights.
Follows the standard EM update pattern for sum nodes:
1. Get cached log-likelihoods for input and this module
2. Compute expectations using: log_weights + log_grads + input_lls - module_lls
3. Normalize to get new log_weights
Args:
data: Input data tensor.
            bias_correction: Whether to apply bias correction (not used by this
                module itself; forwarded to the input module).
cache: Cache dictionary with log-likelihoods from forward pass.
Raises:
MissingCacheError: If required log-likelihoods are not found in cache.
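
        Example:
            A rough sketch of the expectation step described above, with the kernel
            dimensions already averaged out; all names and shapes are illustrative::

                import torch

                B, N, C_out, C_in, R = 4, 16, 2, 3, 1
                log_w = torch.log_softmax(torch.randn(C_out, C_in, R), dim=1)
                input_lls = torch.randn(B, N, C_in, R)
                module_lls = torch.logsumexp(log_w[None, None] + input_lls[:, :, None], dim=3)
                log_grads = torch.zeros(B, N, C_out, R)  # uniform gradients, e.g. at the root
                log_exp = (
                    log_w[None, None]
                    + log_grads[:, :, :, None]
                    + input_lls[:, :, None]
                    - module_lls[:, :, :, None]
                )
                new_log_w = torch.log_softmax(torch.logsumexp(log_exp, dim=(0, 1)), dim=1)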
"""
if cache is None:
cache = Cache()
with torch.no_grad():
# Get cached log-likelihoods
input_lls = cache["log_likelihood"].get(self.inputs)
if input_lls is None:
raise MissingCacheError(
"Input log-likelihoods not found in cache. Call log_likelihood first."
)
module_lls = cache["log_likelihood"].get(self)
if module_lls is None:
raise MissingCacheError(
"Module log-likelihoods not found in cache. Call log_likelihood first."
)
# input_lls shape: (batch, features, in_channels, reps)
# module_lls shape: (batch, features, out_channels, reps)
# log_weights shape: (out_channels, in_channels, k, k, reps)
batch_size = input_lls.shape[0]
num_features = input_lls.shape[1]
in_channels = input_lls.shape[2]
out_channels = module_lls.shape[2]
num_reps = self.out_shape.repetitions
# Get log gradients from module output
# grad is set during backward pass or EM routine
if module_lls.grad is None:
# If no gradient, use uniform (this happens at the root)
log_grads = torch.zeros_like(module_lls)
else:
log_grads = torch.log(module_lls.grad + 1e-10)
# Current log weights: (out_c, in_c, k, k, reps)
# Average over kernel spatial dims for simplicity
log_weights = self.log_weights.mean(dim=(2, 3)) # (out_c, in_c, reps)
# Reshape for broadcasting:
# log_weights: (1, 1, out_c, in_c, reps)
# log_grads: (batch, features, out_c, 1, reps)
# input_lls: (batch, features, 1, in_c, reps)
# module_lls: (batch, features, out_c, 1, reps)
log_weights = log_weights.unsqueeze(0).unsqueeze(0) # (1, 1, out_c, in_c, reps)
log_grads = log_grads.unsqueeze(3) # (batch, features, out_c, 1, reps)
input_lls = input_lls.unsqueeze(2) # (batch, features, 1, in_c, reps)
module_lls = module_lls.unsqueeze(3) # (batch, features, out_c, 1, reps)
# Compute log expectations
# This follows the standard EM derivation for mixture models
log_expectations = log_weights + log_grads + input_lls - module_lls
# Shape: (batch, features, out_c, in_c, reps)
# Sum over batch and features dimensions
log_expectations = torch.logsumexp(log_expectations, dim=0) # (features, out_c, in_c, reps)
log_expectations = torch.logsumexp(log_expectations, dim=0) # (out_c, in_c, reps)
# Normalize over in_channels (sum dimension for this module)
# The sum_dim for SumConv is dimension 1 (in_channels)
log_expectations = torch.log_softmax(log_expectations, dim=1)
# Update log_weights: need to expand to full kernel shape
# Current shape: (out_c, in_c, reps)
# Target shape: (out_c, in_c, k, k, reps)
k = self.kernel_size
new_log_weights = log_expectations.unsqueeze(2).unsqueeze(3) # (out_c, in_c, 1, 1, reps)
new_log_weights = new_log_weights.expand(-1, -1, k, k, -1) # (out_c, in_c, k, k, reps)
# Set new weights
self.logits.data = new_log_weights.contiguous()
# Recursively call EM on inputs
self.inputs.expectation_maximization(data, cache=cache, bias_correction=bias_correction)
def marginalize(
self,
marg_rvs: list[int],
prune: bool = True,
cache: Cache | None = None,
) -> SumConv | Module | None:
"""Marginalize out specified random variables.
Args:
marg_rvs: List of random variable indices to marginalize.
prune: Whether to prune unnecessary nodes.
cache: Optional cache for storing intermediate results.
Returns:
SumConv | Module | None: Marginalized module or None if fully marginalized.
"""
# Compute scope intersection
layer_scope = self.scope
mutual_rvs = set(layer_scope.query).intersection(set(marg_rvs))
# Fully marginalized
if len(mutual_rvs) == len(layer_scope.query):
return None
# Marginalize input
marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
if marg_input is None:
return None
# For now, return a new SumConv with marginalized input
# Note: This is a simplified implementation
return SumConv(
inputs=marg_input,
out_channels=self.out_shape.channels,
kernel_size=self.kernel_size,
num_repetitions=self.out_shape.repetitions,
)