# Source code for spflow.zoo.pic.functional_sharing

"""Functional sharing utilities for Probabilistic Integral Circuits.

This module provides neural network components for functional sharing in PICs,
as described in Section 3.3 of the NeurIPS 2024 paper:
"Scaling Continuous Latent Variable Models as Probabilistic Integral Circuits"

Functional sharing reduces the number of parameters and speeds up QPC materialization
by sharing MLPs across multiple PIC units.
"""

from __future__ import annotations

from typing import Callable, List, Optional, Sequence

import torch
from einops import rearrange
from torch import Tensor, nn
from torch.nn import functional as F


class FourierFeatures(nn.Module):
    """Random Fourier feature encoding used as a positional encoding.

    Projects a low-dimensional input through a fixed random frequency matrix
    and returns the sine and cosine of the projection, which helps a
    downstream MLP represent high-frequency functions.

    From paper Eq. 5:
        FF : R^I → R^M

    Attributes:
        B: Random frequency matrix of shape (in_features, out_features),
            registered as a buffer (not trained).
        out_dim: Size of the last dimension produced by ``forward``
            (``2 * out_features``: one sine and one cosine per frequency).
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        scale: float = 1.0,
    ) -> None:
        """Initialize FourierFeatures layer.

        Args:
            in_features: Input dimension I.
            out_features: Number of random frequencies (the layer output has
                twice this many features because of sin/cos).
            scale: Multiplier applied to the standard-normal frequencies.
        """
        super().__init__()

        # Sampled once at construction; a buffer moves with the module
        # (device/dtype, state_dict) but receives no gradient updates.
        frequencies = torch.randn(in_features, out_features) * scale
        self.register_buffer("B", frequencies)

        # One sine and one cosine channel per frequency.
        self.out_dim = 2 * out_features

    def forward(self, x: Tensor) -> Tensor:
        """Apply Fourier feature encoding.

        Args:
            x: Input tensor of shape (..., in_features).

        Returns:
            Tensor of shape (..., out_features * 2): sines first, then cosines.
        """
        # (..., in_features) @ (in_features, out_features) -> (..., out_features)
        phase = torch.matmul(x, self.B)
        return torch.cat((phase.sin(), phase.cos()), dim=-1)
class SharedMLP(nn.Module):
    """Shared MLP backbone for functional sharing.

    Parameterizes the shared function f in functional sharing. Uses Fourier
    features followed by MLP layers with nonlinearity.

    From paper:
        φ^(γ) : R^I → R^M := φ_L ∘ ... ∘ φ_1 ∘ FF

    Attributes:
        input_dim: Dimension of the raw input.
        hidden_dim: Width of every hidden layer and of the output.
        fourier: FourierFeatures input encoding.
        layers: Sequential stack of ``Linear`` + activation pairs.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        num_layers: int = 2,
        activation: Optional[nn.Module] = None,
        fourier_scale: float = 1.0,
    ) -> None:
        """Initialize SharedMLP.

        Args:
            input_dim: Dimension of input (e.g., latent variable dimension).
            hidden_dim: Dimension of hidden layers M.
            num_layers: Number of hidden layers L. The first layer is always
                created, so values below 1 still yield one layer.
            activation: Activation function ψ. Defaults to a fresh
                ``nn.SiLU()`` per instance. The given module object is reused
                at every layer position.
            fourier_scale: Scale for Fourier features.
        """
        super().__init__()

        # Build the default here rather than in the signature: a module
        # instance in a default argument is created once at definition time
        # and would be shared by every SharedMLP constructed with defaults.
        if activation is None:
            activation = nn.SiLU()

        self.input_dim = input_dim
        self.hidden_dim = hidden_dim

        # Fourier feature encoding (created first so RNG order is unchanged).
        self.fourier = FourierFeatures(input_dim, hidden_dim, scale=fourier_scale)
        fourier_out = self.fourier.out_dim

        # First layer: Fourier output -> hidden.
        layers: List[nn.Module] = [nn.Linear(fourier_out, hidden_dim), activation]

        # Remaining hidden layers: hidden -> hidden.
        for _ in range(num_layers - 1):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation)

        self.layers = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """Forward pass through shared MLP.

        Args:
            x: Input tensor of shape (..., input_dim).

        Returns:
            Hidden representation of shape (..., hidden_dim).
        """
        encoded = self.fourier(x)
        return self.layers(encoded)
class MultiHeadedMLP(nn.Module):
    """Multi-headed MLP for C-sharing (composite sharing).

    Shares a SharedMLP backbone across multiple functions, with separate
    output heads for each function. This enables efficient C-sharing where
    fi = hi ∘ f, sharing inner function f.

    From paper (neural C-sharing):
        fi : R^M → R := softplus(h^(i) · φ^(γ) + b^(i))

    Attributes:
        shared: SharedMLP backbone.
        num_heads: Number of output heads.
        heads: ModuleList of ``Linear(hidden_dim, 1)`` heads, one per function.
        output_activation: Module applied to the head outputs.
    """

    def __init__(
        self,
        shared_mlp: SharedMLP,
        num_heads: int,
        output_activation: Optional[nn.Module] = None,
    ) -> None:
        """Initialize MultiHeadedMLP.

        Args:
            shared_mlp: Shared MLP backbone; must expose ``hidden_dim``.
            num_heads: Number of output heads N.
            output_activation: Activation for outputs. ``None`` selects
                softplus (for positivity); any module passed explicitly is
                used as given.
        """
        super().__init__()

        self.shared = shared_mlp
        self.num_heads = num_heads

        # Create heads: each is a (h^(i), b^(i)) pair.
        hidden_dim = shared_mlp.hidden_dim
        self.heads = nn.ModuleList([nn.Linear(hidden_dim, 1) for _ in range(num_heads)])

        # Compare against None instead of relying on truthiness: container
        # modules define __len__, so an empty nn.Sequential() (an identity
        # activation) is falsy and would otherwise be replaced by Softplus.
        if output_activation is None:
            output_activation = nn.Softplus()
        self.output_activation = output_activation

    def forward(self, x: Tensor, head_idx: Optional[int] = None) -> Tensor:
        """Forward pass through multi-headed MLP.

        Args:
            x: Input tensor of shape (..., input_dim).
            head_idx: Optional specific head index. If None, returns all heads.

        Returns:
            If head_idx is specified: Output of shape (..., 1).
            Otherwise: Output of shape (..., num_heads).
        """
        # Shared backbone, evaluated once for all heads.
        h = self.shared(x)  # (..., hidden_dim)

        if head_idx is not None:
            # Single head output.
            return self.output_activation(self.heads[head_idx](h))

        # All heads, concatenated along the last dimension.
        outputs = torch.cat([head(h) for head in self.heads], dim=-1)  # (..., num_heads)
        return self.output_activation(outputs)
class FunctionGroup(nn.Module):
    """Container for grouping PIC units with functional sharing.

    Groups integral/input units that share the same MLP for efficient
    materialization.

    Heads are created by ``finalize()``, which sizes them from the units
    present at that moment. Units added afterwards get no head until
    ``finalize()`` is called again, and calling it again creates new,
    freshly initialized heads.

    Attributes:
        sharing_type: Type of sharing ("f" for F-sharing, "c" for C-sharing).
        units: List of units in this group (plain list, not registered modules).
        mlp: Shared MLP for this group.
    """

    def __init__(
        self,
        sharing_type: str = "c",
        input_dim: int = 1,
        hidden_dim: int = 64,
        num_layers: int = 2,
    ) -> None:
        """Initialize FunctionGroup.

        Args:
            sharing_type: "f" for F-sharing (all same), "c" for C-sharing (multi-headed).
            input_dim: Input dimension for MLP.
            hidden_dim: Hidden dimension for MLP.
            num_layers: Number of layers in MLP.

        Raises:
            ValueError: If ``sharing_type`` is neither "c" nor "f".
        """
        super().__init__()

        if sharing_type not in {"c", "f"}:
            raise ValueError("sharing_type must be 'c' (C-sharing) or 'f' (F-sharing).")

        self.sharing_type = sharing_type
        self.units: list = []

        # Create shared MLP.
        self.mlp = SharedMLP(
            input_dim=input_dim,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
        )

        # Populated by finalize() once the units are known.
        self._multi_headed: Optional[MultiHeadedMLP] = None
        self._f_head: Optional[nn.Linear] = None

    def add_unit(self, unit) -> int:
        """Add a unit to this group.

        Args:
            unit: PIC unit (Integral or input unit).

        Returns:
            Index of the unit in this group (for C-sharing head selection).
        """
        idx = len(self.units)
        self.units.append(unit)
        return idx

    def finalize(self) -> None:
        """Finalize the group after all units are added.

        Creates the multi-headed MLP (C-sharing) or the single shared head
        (F-sharing). Does nothing when the group has no units. The new heads
        are placed on the same device and dtype as the shared MLP, so a lazy
        call after the group was moved (e.g. ``.to(device)``) stays consistent.
        """
        if not self.units:
            return

        # Reference parameter of the backbone; None only if it has no parameters.
        ref = next(self.mlp.parameters(), None)

        if self.sharing_type == "c":
            multi_headed = MultiHeadedMLP(
                shared_mlp=self.mlp,
                num_heads=len(self.units),
            )
            if ref is not None:
                # Only the heads are new; the backbone is already in place.
                multi_headed.heads.to(device=ref.device, dtype=ref.dtype)
            self._multi_headed = multi_headed
        else:
            f_head = nn.Linear(self.mlp.hidden_dim, 1)
            if ref is not None:
                f_head.to(device=ref.device, dtype=ref.dtype)
            self._f_head = f_head

    def evaluate_batched(self, z: Tensor, y: Tensor) -> Tensor:
        """Evaluate all functions in the group in a single shared-backbone pass.

        This implements the C-sharing/F-sharing semantics from Sec. 3.3 of the paper:
        - C-sharing: different heads over a shared backbone
        - F-sharing: a single head shared across units

        Args:
            z: Tensor with last dimension matching the z-input dimensionality.
            y: Tensor with last dimension matching the y-input dimensionality.
                `z` and `y` must be broadcastable to the same leading shape.

        Returns:
            If C-sharing: Tensor of shape (num_units, *leading_shape).
            If F-sharing: Tensor of shape (1, *leading_shape).

        Raises:
            RuntimeError: If the group has no units, so no head could be created.
        """
        is_c_sharing = self.sharing_type == "c"

        # Lazily create the heads on first use.
        if (self._multi_headed if is_c_sharing else self._f_head) is None:
            self.finalize()

        # finalize() is a no-op for an empty group; fail with a clear message
        # instead of an assert, which is stripped when running under -O.
        if (self._multi_headed if is_c_sharing else self._f_head) is None:
            raise RuntimeError(
                "FunctionGroup has no units; call add_unit() before evaluating."
            )

        # Broadcast z and y to a common leading shape (excluding last dim).
        leading_shape = torch.broadcast_shapes(z.shape[:-1], y.shape[:-1])
        z_b = torch.broadcast_to(z, (*leading_shape, z.shape[-1]))
        y_b = torch.broadcast_to(y, (*leading_shape, y.shape[-1]))

        # Concatenate features and flatten all leading dims into one batch axis.
        xy = torch.cat([z_b, y_b], dim=-1)
        flat = rearrange(xy, "... d -> (...) d")

        shared_h = self.mlp(flat)  # (N, hidden_dim)

        if is_c_sharing:
            # Apply the heads directly to the shared representation, bypassing
            # MultiHeadedMLP.forward so the backbone is evaluated only once.
            outputs = torch.cat([head(shared_h) for head in self._multi_headed.heads], dim=-1)
            outputs = self._multi_headed.output_activation(outputs)  # (N, num_units)
            outputs = rearrange(outputs, "n u -> u n")
        else:
            outputs = rearrange(F.softplus(self._f_head(shared_h)), "n 1 -> 1 n")

        # Restore the broadcast leading shape behind the unit axis.
        return outputs.reshape(outputs.shape[0], *leading_shape)

    def get_function(self, unit_idx: int = 0) -> Callable[[Tensor, Tensor], Tensor]:
        """Get a callable function for a specific unit/head.

        The returned callable preserves broadcast shapes: for broadcastable `z`
        and `y`, it returns a tensor with the broadcasted leading shape. Each
        call evaluates the whole group and then selects one row.

        Args:
            unit_idx: Index of the unit in this group (only used for C-sharing).

        Returns:
            Callable mapping `(z, y)` to a positive tensor.
        """

        def _fn(z: Tensor, y: Tensor) -> Tensor:
            outputs = self.evaluate_batched(z, y)
            # F-sharing has a single row regardless of the requested unit.
            head = 0 if self.sharing_type == "f" else unit_idx
            return outputs[head]

        return _fn