Source code for spflow.modules.conv.prod_conv

"""Convolutional product layer for probabilistic circuits.

Provides ProdConv, which computes products over spatial patches (sums in log-space),
reducing spatial dimensions while aggregating scopes.
"""

from __future__ import annotations

import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F

from spflow.meta.data.scope import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.utils.cache import Cache, cached
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context
from spflow.modules.conv.utils import expand_sampling_context, upsample_sampling_context


class ProdConv(Module):
    """Convolutional product layer for probabilistic circuits.

    Computes products over spatial patches, reducing spatial dimensions by the
    kernel size factor. This is equivalent to summing log-likelihoods within
    patches. No learnable parameters.

    Scopes are aggregated per patch: a 2×2 patch containing Scope(0), Scope(1),
    Scope(2), Scope(3) produces Scope([0, 1, 2, 3]).

    Attributes:
        inputs (Module): Input module providing log-likelihoods.
        kernel_size_h (int): Kernel height.
        kernel_size_w (int): Kernel width.
        padding_h (int): Padding in height dimension.
        padding_w (int): Padding in width dimension.
    """
    def __init__(
        self,
        inputs: Module,
        kernel_size_h: int,
        kernel_size_w: int,
        padding_h: int = 0,
        padding_w: int = 0,
    ) -> None:
        """Create a ProdConv module for spatial product operations.

        Args:
            inputs: Input module providing log-likelihoods with spatial structure.
            kernel_size_h: Height of the pooling kernel.
            kernel_size_w: Width of the pooling kernel.
            padding_h: Padding in height dimension (added on both sides).
            padding_w: Padding in width dimension (added on both sides).

        Raises:
            ValueError: If kernel sizes are < 1 or the number of input features
                is not a perfect square.
        """
        super().__init__()

        if kernel_size_h < 1:
            raise ValueError(f"kernel_size_h must be >= 1, got {kernel_size_h}")
        if kernel_size_w < 1:
            raise ValueError(f"kernel_size_w must be >= 1, got {kernel_size_w}")

        self.inputs = inputs
        self.kernel_size_h = kernel_size_h
        self.kernel_size_w = kernel_size_w
        self.padding_h = padding_h
        self.padding_w = padding_w

        # Infer input shape
        input_shape = self.inputs.out_shape
        in_features = input_shape.features

        # Infer spatial dimensions (assumes a square grid for now, can be extended)
        in_h = in_w = int(np.sqrt(in_features))
        if in_h * in_w != in_features:
            raise ValueError(f"Features {in_features} must be a perfect square")

        # Compute output spatial dimensions
        padded_h = in_h + 2 * padding_h
        padded_w = in_w + 2 * padding_w
        out_h = padded_h // kernel_size_h
        out_w = padded_w // kernel_size_w

        self._input_h = in_h
        self._input_w = in_w
        self._output_h = out_h
        self._output_w = out_w

        # Shape computation
        self.in_shape = input_shape
        self.out_shape = ModuleShape(
            features=out_h * out_w,
            channels=input_shape.channels,
            repetitions=input_shape.repetitions,
        )

        # Compute aggregated scope
        # Each output position covers a kernel_size_h x kernel_size_w patch
        self._compute_scope()
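    # Shape example: with in_features = 64 the input is treated as an 8x8 grid;
    # kernel_size_h = kernel_size_w = 2 with no padding gives out_h = out_w = 4,
    # so out_shape.features = 16 while channels and repetitions are unchanged.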
    def _compute_scope(self) -> None:
        """Compute the aggregated scope for this layer."""
        # Get input scopes
        input_scopes = self.inputs.feature_to_scope
        # For now, compute the overall scope as the union of all input scopes
        all_rvs = set()
        for scope in input_scopes.flatten():
            all_rvs.update(scope.query)
        self.scope = Scope(sorted(all_rvs))

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Aggregated scopes per output feature.

        Each output feature's scope is the join of all input scopes within its patch.

        Returns:
            np.ndarray: 2D array of Scope objects (features, repetitions).
        """
        input_f2s = self.inputs.feature_to_scope  # (in_features, reps)
        in_h, in_w = self._input_h, self._input_w
        out_h, out_w = self._output_h, self._output_w
        kh, kw = self.kernel_size_h, self.kernel_size_w
        ph, pw = self.padding_h, self.padding_w
        num_reps = input_f2s.shape[1]

        result = np.empty((out_h * out_w, num_reps), dtype=object)

        for r in range(num_reps):
            # Reshape input scopes to spatial layout: (in_h, in_w)
            input_scopes_2d = input_f2s[:, r].reshape(in_h, in_w)
            out_idx = 0
            for oh in range(out_h):
                for ow in range(out_w):
                    # Compute input patch bounds
                    start_h = oh * kh - ph
                    end_h = start_h + kh
                    start_w = ow * kw - pw
                    end_w = start_w + kw
                    # Collect scopes from valid positions (not padding)
                    patch_scopes = []
                    for ih in range(max(0, start_h), min(in_h, end_h)):
                        for iw in range(max(0, start_w), min(in_w, end_w)):
                            patch_scopes.append(input_scopes_2d[ih, iw])
                    # Join all scopes in the patch
                    if patch_scopes:
                        result[out_idx, r] = Scope.join_all(patch_scopes)
                    else:
                        # Edge case: the entire patch lies in the padding
                        result[out_idx, r] = Scope([])
                    out_idx += 1

        return result

    def extra_repr(self) -> str:
        return (
            f"kernel=({self.kernel_size_h}, {self.kernel_size_w}), "
            f"padding=({self.padding_h}, {self.padding_w})"
        )
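    # Padding example for feature_to_scope above: with a 2x2 kernel and
    # padding_h = padding_w = 1, the top-left output patch spans input
    # rows/cols -1..0, so only input position (0, 0) contributes to its scope;
    # the padded positions are skipped when joining scopes.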
    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log likelihood by summing within patches.

        Uses a depthwise convolution with a ones kernel to efficiently sum
        log-probabilities within patches.

        Args:
            data: Input data of shape (batch_size, num_features).
            cache: Cache for intermediate computations.

        Returns:
            Tensor: Log-likelihood of shape (batch, out_features, channels, reps).
        """
        if cache is None:
            cache = Cache()

        # Get input log-likelihoods: (batch, features, channels, reps)
        ll = self.inputs.log_likelihood(data, cache=cache)
        batch_size = ll.shape[0]
        in_features = ll.shape[1]
        channels = ll.shape[2]
        reps = ll.shape[3]

        in_h, in_w = self._input_h, self._input_w
        out_h, out_w = self._output_h, self._output_w
        kh, kw = self.kernel_size_h, self.kernel_size_w
        ph, pw = self.padding_h, self.padding_w

        # Reshape to spatial layout: (batch, channels, H, W, reps)
        ll = ll.permute(0, 2, 1, 3)  # (batch, channels, features, reps)
        ll = ll.view(batch_size, channels, in_h, in_w, reps)

        # Merge batch and reps for an efficient conv: (batch * reps, channels, H, W)
        ll = ll.permute(0, 4, 1, 2, 3)  # (batch, reps, channels, H, W)
        ll = ll.contiguous().view(batch_size * reps, channels, in_h, in_w)

        # Apply a depthwise convolution with a ones kernel
        # This sums values within each patch
        ones_kernel = torch.ones(channels, 1, kh, kw, device=ll.device, dtype=ll.dtype)
        result = F.conv2d(ll, ones_kernel, stride=(kh, kw), padding=(ph, pw), groups=channels)

        # Reshape back: (batch, reps, channels, out_H, out_W)
        result = result.view(batch_size, reps, channels, out_h, out_w)

        # Convert to SPFlow format: (batch, out_features, channels, reps)
        result = result.permute(0, 3, 4, 2, 1)  # (batch, out_H, out_W, channels, reps)
        result = result.contiguous().view(batch_size, out_h * out_w, channels, reps)

        return result
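    # Note on the conv trick in log_likelihood: summing log-likelihoods over a
    # kh x kw patch equals the log of the product of the patch members, i.e. one
    # product node per patch. Zero-padding contributes log(1) = 0 and therefore
    # acts as a neutral factor.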
    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        sampling_ctx: SamplingContext | None = None,
    ) -> Tensor:
        """Generate samples by delegating to the input module.

        ProdConv has no learnable parameters, so sampling simply expands the
        sampling context to match the input features and delegates.

        Args:
            num_samples: Number of samples to generate.
            data: Data tensor with NaN values to fill with samples.
            is_mpe: Whether to perform maximum a posteriori estimation.
            cache: Optional cache dictionary.
            sampling_ctx: Optional sampling context.

        Returns:
            Tensor: Sampled values.
        """
        if cache is None:
            cache = Cache()

        # Handle the num_samples case
        if data is None:
            if num_samples is None:
                num_samples = 1
            data = torch.full((num_samples, len(self.scope.query)), float("nan")).to(self.device)

        batch_size = data.shape[0]

        # Initialize sampling context
        sampling_ctx = init_default_sampling_context(sampling_ctx, batch_size, data.device)

        # Expand channel_index and mask to match input features
        in_features = self.in_shape.features
        out_features = self.out_shape.features
        current_features = sampling_ctx.channel_index.shape[1]

        # If channel_index is smaller than in_features, expand it
        if current_features < in_features:
            if current_features == 1:
                # Broadcast the single value to all input features
                expand_sampling_context(sampling_ctx, in_features)
            elif current_features == out_features:
                # Upsample from output to input resolution
                upsample_sampling_context(
                    sampling_ctx,
                    current_height=self._output_h,
                    current_width=self._output_w,
                    scale_h=self.kernel_size_h,
                    scale_w=self.kernel_size_w,
                )
                # Handle the padding case: trim to the actual input size
                if sampling_ctx.channel_index.shape[1] > in_features:
                    channel_idx = sampling_ctx.channel_index[:, :in_features].contiguous()
                    mask = sampling_ctx.mask[:, :in_features].contiguous()
                    sampling_ctx.update(channel_index=channel_idx, mask=mask)
            else:
                # Just broadcast the first element
                expand_sampling_context(sampling_ctx, in_features)

        # Sample from the input module
        self.inputs.sample(
            data=data,
            is_mpe=is_mpe,
            cache=cache,
            sampling_ctx=sampling_ctx,
        )

        return data
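    # Sampling-context example: a context at output resolution (e.g. a 4x4 grid
    # of 16 features) with a 2x2 kernel is upsampled to the 8x8 input grid, the
    # intent being that each input position within a patch inherits its patch's
    # channel selection before delegating to the input module.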
    def expectation_maximization(
        self,
        data: Tensor,
        bias_correction: bool = True,
        cache: Cache | None = None,
    ) -> None:
        """EM step (delegates to the input module; no learnable parameters).

        Args:
            data: Input data tensor for the EM step.
            bias_correction: Whether to apply bias correction.
            cache: Optional cache for storing intermediate results.
        """
        # The product layer has no learnable parameters, so delegate to the input
        self.inputs.expectation_maximization(data, cache=cache, bias_correction=bias_correction)
    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> ProdConv | Module | None:
        """Marginalize out the specified random variables.

        Args:
            marg_rvs: List of random variable indices to marginalize.
            prune: Whether to prune unnecessary nodes.
            cache: Optional cache for storing intermediate results.

        Returns:
            ProdConv | Module | None: Marginalized module, or None if the layer
            is fully marginalized.
        """
        # Compute the scope intersection
        layer_scope = self.scope
        mutual_rvs = set(layer_scope.query).intersection(set(marg_rvs))

        # Fully marginalized
        if len(mutual_rvs) == len(layer_scope.query):
            return None

        # Marginalize the input module
        marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
        if marg_input is None:
            return None

        # For now, return a new ProdConv with the marginalized input
        # Note: This is a simplified implementation
        return ProdConv(
            inputs=marg_input,
            kernel_size_h=self.kernel_size_h,
            kernel_size_w=self.kernel_size_w,
            padding_h=self.padding_h,
            padding_w=self.padding_w,
        )
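The identity used by log_likelihood can be checked in isolation: a depthwise conv2d with a ones kernel and stride equal to the kernel size sums values over non-overlapping patches, which in log-space is the product over each patch. A minimal, self-contained sketch of that check (plain PyTorch, independent of spflow; the tensor names are illustrative):

import torch
from torch.nn import functional as F

# Fake per-position log-likelihoods: (batch, channels, H, W)
ll = torch.randn(2, 3, 4, 4)

# Depthwise ones kernel: one kernel per channel, stride == kernel size
ones = torch.ones(3, 1, 2, 2)
conv_sum = F.conv2d(ll, ones, stride=(2, 2), groups=3)  # (2, 3, 2, 2)

# Reference: unfold each 2x2 patch and sum its log-likelihoods explicitly
patches = ll.unfold(2, 2, 2).unfold(3, 2, 2)  # (2, 3, 2, 2, 2, 2)
manual_sum = patches.sum(dim=(-1, -2))        # (2, 3, 2, 2)

assert torch.allclose(conv_sum, manual_sum, atol=1e-6)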