"""Convolutional Probabilistic Circuit.
Provides ConvPc, a multi-layer architecture that stacks alternating
SumConv and ProdConv layers on top of a leaf distribution.
"""
from __future__ import annotations
import math
import numpy as np
import torch
from torch import Tensor
from spflow.modules.conv.prod_conv import ProdConv
from spflow.modules.conv.sum_conv import SumConv
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.sums import Sum
from spflow.modules.sums.repetition_mixing_layer import RepetitionMixingLayer
from spflow.utils.cache import Cache, cached
from spflow.utils.sampling_context import SamplingContext
def compute_non_overlapping_kernel_and_padding(
H_data: int, W_data: int, H_target: int, W_target: int
) -> tuple[tuple[int, int], tuple[int, int]]:
"""Compute kernel size and padding for non-overlapping convolution.
Computes kernel size and padding such that a single F.conv2d with
stride=kernel_size and dilation=1 transforms the input to the target size.
Args:
H_data: Input height.
W_data: Input width.
H_target: Target output height.
W_target: Target output width.
Returns:
Tuple of (kernel_size, padding) where:
kernel_size: (kH, kW)
padding: (pH, pW)
Raises:
ValueError: If any dimension is non-positive.
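
    Example:
        A couple of worked values (they follow directly from the computation
        below):

        >>> compute_non_overlapping_kernel_and_padding(28, 28, 4, 4)
        ((7, 7), (0, 0))
        >>> compute_non_overlapping_kernel_and_padding(30, 30, 4, 4)
        ((8, 8), (1, 1))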
"""
if H_data <= 0 or W_data <= 0 or H_target <= 0 or W_target <= 0:
raise ValueError("All dimensions must be positive.")
# Compute required kernel sizes
kH = math.ceil(H_data / H_target)
kW = math.ceil(W_data / W_target)
    # Compute symmetric padding so that input + 2 * padding covers kernel * target
    padded_H = kH * H_target
    padded_W = kW * W_target
    total_pad_H = max(padded_H - H_data, 0)
    total_pad_W = max(padded_W - W_data, 0)
    # Round the per-side padding up so an odd total is not truncated by integer
    # division, which would shrink the output below the target size.
    pH = (total_pad_H + 1) // 2
    pW = (total_pad_W + 1) // 2
return (kH, kW), (pH, pW)
class ConvPc(Module):
"""Convolutional Probabilistic Circuit.
    Builds a multi-layer circuit with alternating ProdConv and SumConv layers
    on top of a leaf distribution. The architecture progressively reduces the
    spatial dimensions while learning mixture weights at each level.

    The layer ordering is:
    Leaf -> ProdConv -> SumConv -> ProdConv -> SumConv -> ... -> Root Sum

    Layer specifications are laid out top-down (from the root to the leaves)
    and then reversed, so the modules themselves are constructed bottom-up
    starting from the leaf.

Attributes:
leaf (Module): Leaf distribution module.
root (Sum): Final sum layer producing scalar output per sample.
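
    Example:
        A minimal sketch of intended usage (the leaf construction is elided,
        since its exact API depends on the chosen leaf module; any Module whose
        scope covers all ``input_height * input_width`` pixels works)::

            leaf = ...  # e.g. a Normal leaf over 28 * 28 pixel random variables
            pc = ConvPc(leaf=leaf, input_height=28, input_width=28,
                        channels=8, depth=2, kernel_size=2)
            lls = pc.log_likelihood(x)  # x: (batch, 784) -> (batch, 1, 1, 1)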
"""
def __init__(
self,
leaf: Module,
input_height: int,
input_width: int,
channels: int,
depth: int,
kernel_size: int = 2,
num_repetitions: int = 1,
use_sum_conv: bool = False,
) -> None:
"""Create a ConvPc for image modeling.
Args:
leaf: Leaf distribution module (e.g., Normal over pixels).
input_height: Height of input image.
input_width: Width of input image.
channels: Number of channels per sum layer.
depth: Number of (ProdConv, SumConv) layer pairs.
kernel_size: Kernel size for pooling (default 2x2).
num_repetitions: Number of independent repetitions.
use_sum_conv: If True, use SumConv layers with kernel-based spatial
weights. If False (default), use regular Sum layers that treat
features independently without spatial awareness.
        Raises:
            ValueError: If depth < 1 or channels < 1.
"""
super().__init__()
self.use_sum_conv = use_sum_conv
if depth < 1:
raise ValueError(f"depth must be >= 1, got {depth}")
if channels < 1:
raise ValueError(f"channels must be >= 1, got {channels}")
self.input_height = input_height
self.input_width = input_width
self.kernel_size = kernel_size
self.depth = depth
# Build layers top-down: start from root (1x1) and work down
# Top-down order: Sum (root) -> ProdConv -> SumConv -> ProdConv -> SumConv -> ... -> Leaf
# We'll build a list of layer specs top-down, then reverse and construct bottom-up
layer_specs = []
# Add root sum layer: 1x1 spatial, reduces channels to 1
layer_specs.append(("sum_root", {"out_channels": 1}))
# Build from top (1x1) down to target spatial size
# Each depth level adds: ProdConv (spatial expansion) -> SumConv (channel mixing)
h, w = 1, 1
        for _ in range(depth):
# ProdConv expands spatial dims by kernel_size
layer_specs.append(
(
"prod",
{"kernel_size": kernel_size},
)
)
# SumConv mixes channels at this spatial level
h, w = h * kernel_size, w * kernel_size
layer_specs.append(
(
"sum",
{
"out_channels": channels,
"kernel_size": kernel_size,
},
)
)
# Now h, w is the target spatial size at the bottom of the conv layers
# We need a final ProdConv to reduce from input_height/width to h/w
# Compute kernel and padding for the bottom ProdConv layer
(kh, kw), (ph, pw) = compute_non_overlapping_kernel_and_padding(
H_data=input_height,
W_data=input_width,
H_target=h,
W_target=w,
)
layer_specs.append(
(
"prod_bottom",
{"kernel_size_h": kh, "kernel_size_w": kw, "padding_h": ph, "padding_w": pw},
)
)
# Reverse the specs so we build bottom-up (leaf -> ... -> root)
layer_specs = list(reversed(layer_specs))
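        # Illustrative trace (depth=2, kernel_size=2, 8x8 input): the reversed specs are
        #   prod_bottom -> sum -> prod -> sum -> prod -> sum_root
        # with spatial sizes 8x8 -> 4x4 -> 4x4 -> 2x2 -> 2x2 -> 1x1 -> 1x1 bottom-up.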
# Now construct layers bottom-up, connecting via .inputs
# Bottom-up order: Leaf -> ProdConv -> SumConv -> ProdConv -> SumConv -> ... -> Sum (root)
current_input = leaf
for layer_type, params in layer_specs:
if layer_type == "prod_bottom":
current_input = ProdConv(
inputs=current_input,
kernel_size_h=params["kernel_size_h"],
kernel_size_w=params["kernel_size_w"],
padding_h=params.get("padding_h", 0),
padding_w=params.get("padding_w", 0),
)
elif layer_type == "sum":
if self.use_sum_conv:
current_input = SumConv(
inputs=current_input,
out_channels=params["out_channels"],
kernel_size=params["kernel_size"],
num_repetitions=num_repetitions,
)
else:
current_input = Sum(
inputs=current_input,
out_channels=params["out_channels"],
num_repetitions=num_repetitions,
)
elif layer_type == "prod":
current_input = ProdConv(
inputs=current_input,
kernel_size_h=params["kernel_size"],
kernel_size_w=params["kernel_size"],
)
elif layer_type == "sum_root":
# Final root sum to produce single output (1x1 spatial, 1 channel)
current_input = Sum(
inputs=current_input,
out_channels=params["out_channels"],
num_repetitions=num_repetitions,
)
self.inputs = current_input
# Add repetition mixing layer if num_repetitions > 1
if num_repetitions > 1:
self.inputs = RepetitionMixingLayer(
inputs=self.inputs,
out_channels=1,
num_repetitions=num_repetitions,
)
# Scope and shape
self.scope = leaf.scope
leaf_shape = leaf.out_shape
self.in_shape = leaf_shape
self.out_shape = ModuleShape(
features=1,
channels=1,
            repetitions=1,  # Always 1: mixed away by RepetitionMixingLayer, or num_repetitions == 1
)
@property
def feature_to_scope(self) -> np.ndarray:
"""Single output feature with full scope."""
return self.inputs.feature_to_scope
def extra_repr(self) -> str:
return (
f"input=({self.input_height}, {self.input_width}), "
f"depth={self.depth}, kernel_size={self.kernel_size}"
)
@cached
def log_likelihood(
self,
data: Tensor,
cache: Cache | None = None,
) -> Tensor:
"""Compute log likelihood through all layers.
Args:
data: Input data of shape (batch_size, num_pixels).
cache: Cache for intermediate computations.
Returns:
Tensor: Log-likelihood of shape (batch, 1, 1, reps).
"""
if cache is None:
cache = Cache()
        # Forward through the root module, which recursively evaluates its inputs:
        # root (mixing layer or Sum) -> Sum/SumConv -> ProdConv -> ... -> leaf
return self.inputs.log_likelihood(data, cache=cache)
def sample(
self,
num_samples: int | None = None,
data: Tensor | None = None,
is_mpe: bool = False,
cache: Cache | None = None,
sampling_ctx: SamplingContext | None = None,
) -> Tensor:
"""Generate samples by sampling top-down through layers.
Delegates sampling to the root module (RepetitionMixingLayer when
num_repetitions > 1, or Sum when num_repetitions == 1), which then
recursively propagates sampling to the leaf.
Args:
num_samples: Number of samples to generate.
data: Data tensor with NaN values to fill with samples.
            is_mpe: If True, perform MPE (most probable explanation) inference instead of sampling.
cache: Optional cache dictionary.
sampling_ctx: Optional sampling context.
Returns:
Tensor: Sampled values of shape (num_samples, num_pixels).
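
        Example:
            Sketch of unconditional and conditional use (shapes assume a
            28x28 = 784 pixel scope; variable names are illustrative)::

                full = pc.sample(num_samples=16)      # (16, 784), every pixel drawn
                partial = x.clone()
                partial[:, 392:] = float("nan")       # mark the lower half as missing
                completed = pc.sample(data=partial)   # NaN entries filled with samples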
"""
if cache is None:
cache = Cache()
# Handle num_samples case
if data is None:
if num_samples is None:
num_samples = 1
            data = torch.full((num_samples, len(self.scope.query)), float("nan"), device=self.device)
# Delegate to root (RepetitionMixingLayer or Sum)
# which handles channel/repetition sampling internally
self.inputs.sample(
data=data,
is_mpe=is_mpe,
cache=cache,
sampling_ctx=sampling_ctx,
)
return data
def expectation_maximization(
self,
data: Tensor,
bias_correction: bool = True,
cache: Cache | None = None,
) -> None:
"""Perform EM update throughout the circuit.
Args:
data: Input data tensor.
bias_correction: Whether to apply bias correction.
cache: Optional cache with log-likelihoods.
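
        Example:
            One plausible per-batch update; whether a log-likelihood pass must be
            run first to populate the cache depends on the surrounding training
            loop::

                cache = Cache()
                pc.log_likelihood(batch, cache=cache)
                pc.expectation_maximization(batch, cache=cache)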
"""
if cache is None:
cache = Cache()
# EM on root (which chains to all layers)
self.inputs.expectation_maximization(data, cache=cache, bias_correction=bias_correction)
def marginalize(
self,
marg_rvs: list[int],
prune: bool = True,
cache: Cache | None = None,
) -> ConvPc | Module | None:
"""Marginalize out specified random variables.
Args:
marg_rvs: List of random variable indices to marginalize.
prune: Whether to prune unnecessary nodes.
cache: Optional cache for storing intermediate results.
Returns:
ConvPc | Module | None: Marginalized module or None if fully marginalized.
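
        Example:
            Marginalizing out the first three pixel random variables (indices
            are illustrative)::

                pc_marg = pc.marginalize([0, 1, 2])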
"""
# For ConvPc, marginalization is complex due to the layered architecture
# Delegate to root which handles it recursively
return self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)