Developer Guide¶
This guide shows how to extend SPFlow by implementing custom modules. By the end, you’ll understand how to create:
Leaf modules — distributions at the input layer
Sum modules — weighted mixtures of child distributions
Product modules — factorizations via conditional independence
Split modules — utilities for partitioning inputs
Prerequisites: Familiarity with PyTorch nn.Module and basic probability theory.
1. Core Concepts¶
1.1 Module Hierarchy¶
All SPFlow modules inherit from Module, which extends torch.nn.Module:
nn.Module
└── Module (abstract base)
├── LeafModule (distributions)
├── Sum (weighted mixtures)
├── Product (factorization)
└── Split (partitioning)
1.2 Shape System¶
SPFlow uses a 3-tuple ModuleShape(features, channels, repetitions) to describe tensor dimensions:
| Dimension | Meaning |
|---|---|
| Features | Number of scope partitions (random variable groupings) |
| Channels | Parallel distributions per feature |
| Repetitions | Independent copies of the structure |
All intermediate tensors have shape: (batch, features, channels, repetitions).
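For example, here is a minimal shape check. It is a sketch that assumes the built-in Normal leaf (used again in Section 5) accepts out_channels and num_repetitions the same way the custom leaves later in this guide do:
import torch
from spflow.meta import Scope
from spflow.modules.leaves import Normal

# A leaf over 4 random variables with 8 channels and 2 repetitions
leaf = Normal(scope=Scope(list(range(4))), out_channels=8, num_repetitions=2)
print(leaf.out_shape)                # expected: features=4, channels=8, repetitions=2

data = torch.randn(16, 4)            # (batch, num_random_variables)
ll = leaf.log_likelihood(data)
print(ll.shape)                      # expected: (16, 4, 8, 2) = (batch, features, channels, repetitions)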
1.3 Required Interface¶
| Leaf Modules | Intermediate Modules |
|---|---|
| Distribution parameters as nn.Parameter | log_likelihood() |
| params() returning a dict of parameters | sample() |
| _torch_distribution_class | marginalize() |
| MLE hooks: _compute_parameter_estimates() / _set_mle_parameters() | feature_to_scope property |
1.4 Sampling Architecture: Top-Down Index Propagation¶
SPFlow uses ancestral sampling with a unique top-down index propagation strategy. Understanding this is crucial for implementing custom modules correctly.
Key insight: Internal nodes (Sum, Product) don’t generate samples—they only update routing indices. Only leaf nodes actually sample from distributions and write values to the output tensor.
The sampling process works as follows:
A data tensor filled with NaN is passed through the entire circuit
A SamplingContext tracks which path to follow through the DAG
Sum nodes select which child channel to sample via their weights
Product nodes expand the context to cover all input features
Leaf nodes generate samples and write them in-place to data
Sampling Flow Example:
Root (Sum)
├── samples from Categorical(weights) to pick child index
└── updates sampling_ctx.channel_index with selected child
↓
Product
├── expands channel_index from (batch, 1) to (batch, num_features)
└── passes expanded context to child
↓
Leaf (Normal)
├── uses channel_index to select which channel's parameters
├── uses repetition_idx to select which repetition's parameters
└── writes samples to data[:, self.scope.query] in-place
This design is efficient (no intermediate tensor allocation) and correct (consistent paths through the circuit).
1.5 The SamplingContext Class¶
The SamplingContext class manages routing state during sampling. It contains:
| Field | Shape | Purpose |
|---|---|---|
| channel_index | (batch, features) | Which output channel to use at each position |
| mask | (batch, features) | Boolean mask—which positions need sampling |
| repetition_idx | (batch,) | Which repetition to use (for multi-repetition circuits) |
Why channel_index?
Sum modules have multiple output channels (mixture components)
During sampling, we must pick exactly one path through the DAG
Parent Sum nodes set channel_index[sample_i, feature_j] = index of selected child
Children use this to gather the correct logits/parameters
Why repetition_idx?
Circuits with num_repetitions > 1 have parallel independent copies
RepetitionMixingLayer selects which repetition to use per-sample
Leaves use it to index their 3D parameter tensors (features, channels, repetitions)
When implementing a custom module’s sample() method, you typically:
Initialize context: sampling_ctx = init_default_sampling_context(sampling_ctx, batch_size, device)
Use current indices to select parameters/weights
Update channel_index and/or mask for children
Call self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)
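Putting these steps together, a minimal sample() skeleton looks roughly like this (a sketch mirroring the key code patterns shown later in this guide; the selection logic in steps 2–3 is module-specific):
def sample(self, num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None):
    # 1. Make sure the data tensor and routing context exist
    data = self._prepare_sample_data(num_samples, data)
    sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)
    # 2. Use the current indices to select parameters/weights (module-specific)
    # 3. Update sampling_ctx.channel_index and/or sampling_ctx.mask for the children (module-specific)
    # 4. Delegate to the inputs; leaves eventually write samples into `data` in-place
    self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=sampling_ctx)
    return data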
2. Implementing a Leaf Module¶
Leaf modules wrap probability distributions. The base class LeafModule handles most functionality—you only need to define:
Distribution parameters as nn.Parameter
A params() method returning a dict of parameters
The PyTorch distribution class to use
MLE estimation logic (for parameter learning)
Example: NoisyNormal¶
A Normal distribution that adds noise to log-likelihoods during training. This demonstrates extending a standard distribution with training-time regularization.
[10]:
import torch
from torch import Tensor, nn
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import init_parameter
class NoisyNormal(LeafModule):
"""Normal distribution with additive noise during training.
Adds Gaussian noise to log-likelihoods during training for regularization.
Deterministic during evaluation.
"""
def __init__(self, scope, out_channels=None, num_repetitions=1,
parameter_fn=None, validate_args=True, loc=None, scale=None,
noise_std: float = 0.1):
super().__init__(
scope=scope, out_channels=out_channels,
num_repetitions=num_repetitions, params=[loc, scale],
parameter_fn=parameter_fn, validate_args=validate_args,
)
# Initialize loc and scale (scale stored in log-space for positivity)
loc = init_parameter(loc, self._event_shape, init=torch.zeros)
scale = init_parameter(scale, self._event_shape, init=torch.ones)
self.loc = nn.Parameter(loc)
self.log_scale = nn.Parameter(torch.log(scale))
self.noise_std = noise_std
@property
def scale(self):
return torch.exp(self.log_scale)
@property
def _supported_value(self):
return 0.0 # Mean is always in support
@property
def _torch_distribution_class(self):
return torch.distributions.Normal
def params(self):
return {"loc": self.loc, "scale": self.scale}
def log_likelihood(self, data, cache=None):
# Overwrite LeafModule implementation to add noise during training
# Call LeafModule implementation
ll = super().log_likelihood(data, cache=cache)
# Add noise during training only
if self.training:
noise = torch.randn_like(ll) * self.noise_std
ll = ll + noise
return ll
def _compute_parameter_estimates(self, data, weights, bias_correction):
# MLE for Normal: loc = weighted mean, scale = weighted std
n = weights.sum(dim=0)
mean = (weights * data).sum(dim=0) / n
var = (weights * (data - mean) ** 2).sum(dim=0) / n
return {"loc": mean, "scale": torch.sqrt(var + 1e-8)}
def _set_mle_parameters(self, params_dict):
self.loc.data = params_dict["loc"]
self.log_scale.data = torch.log(params_dict["scale"])
[11]:
# Quick test
from spflow.meta import Scope
leaf = NoisyNormal(scope=Scope([0]), out_channels=3, noise_std=0.5)
# Training mode: noise added
leaf.train()
data = torch.randn(5, 1)
ll1 = leaf.log_likelihood(data)
ll2 = leaf.log_likelihood(data)
print(f"Training: outputs differ = {not torch.allclose(ll1, ll2)}")
# Eval mode: deterministic
leaf.eval()
ll1 = leaf.log_likelihood(data)
ll2 = leaf.log_likelihood(data)
print(f"Eval: outputs identical = {torch.allclose(ll1, ll2)}")
Training: outputs differ = True
Eval: outputs identical = True
Key points:
Store constrained parameters in transformed space (e.g., log_scale for positive scale)
init_parameter() handles shape inference from out_channels
Override log_likelihood() to add custom behavior while calling super()
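As a quick check on the first point (continuing the session above, where leaf is the NoisyNormal instance):
# scale is recovered via exp(log_scale), so it stays positive regardless of gradient updates
print((leaf.scale > 0).all())  # expected: tensor(True)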
How Leaf Sampling Works:
Leaves are the only modules that actually generate samples. The sampling flow is:
Receive the data tensor with NaN at positions to sample
Use sampling_ctx.channel_index to select which channel's parameters
Use sampling_ctx.repetition_idx to select which repetition's parameters
Sample from the distribution (or take the mode for MPE)
Write samples in-place to data[:, self.scope.query]
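A simplified sketch of those steps (single repetition for brevity; the actual LeafModule implementation is more general):
# Gather the per-sample parameters selected by the sampling context (repetition 0 only)
loc = self.loc[..., 0]                          # (features, channels)
scale = self.scale[..., 0]                      # (features, channels)
idx = sampling_ctx.channel_index.unsqueeze(-1)  # (batch, features, 1)
loc_sel = loc.unsqueeze(0).expand(idx.shape[0], -1, -1).gather(2, idx).squeeze(-1)
scale_sel = scale.unsqueeze(0).expand(idx.shape[0], -1, -1).gather(2, idx).squeeze(-1)
# Take the mode for MPE, otherwise draw a sample, then write in-place at this leaf's scope
samples = loc_sel if is_mpe else torch.distributions.Normal(loc_sel, scale_sel).sample()
data[:, self.scope.query] = samples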
3. Implementing a Sum Module¶
Sum modules compute weighted mixtures: \(p(x) = \sum_i w_i \cdot p_i(x)\).
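In log-space this becomes \(\log p(x) = \operatorname{logsumexp}_i\left(\log w_i + \log p_i(x)\right)\), which is exactly how the log_likelihood() below combines the child log-likelihoods with the log-weights.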
We now implement a custom sum module from scratch to demonstrate what is necessary to extend SPFlow with new sum-like operations. Our NoisySum adds noise during training for regularization.
To implement any sum module, you need:
Weight parameters (stored as logits for unconstrained optimization)
log_likelihood() using logsumexp for numerical stability
sample() that selects input channels based on weights
A feature_to_scope property mapping features to scopes
Example: NoisySum¶
A sum module that adds Gaussian noise during training.
[12]:
import numpy as np
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.cat import Cat
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import proj_convex_to_real
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context
class NoisySum(Module):
"""Sum module with additive noise during training.
Adds Gaussian noise to log-likelihoods during training for regularization.
Deterministic during evaluation.
"""
def __init__(self, inputs, out_channels: int, noise_std: float = 0.1, num_repetitions: int = 1):
super().__init__()
# Handle single module or list of modules
if isinstance(inputs, list):
self.inputs = Cat(inputs, dim=2) if len(inputs) > 1 else inputs[0]
else:
self.inputs = inputs
self.scope = self.inputs.scope
self.noise_std = noise_std
# Shape computation
in_shape = self.inputs.out_shape
self.in_shape = in_shape
self.out_shape = ModuleShape(in_shape.features, out_channels, num_repetitions)
# Weight shape: (features, in_channels, out_channels, repetitions)
self._weights_shape = (
in_shape.features, in_shape.channels, out_channels, num_repetitions
)
# Initialize weights randomly (store as logits for unconstrained optimization)
weights = torch.rand(self._weights_shape) + 1e-8
weights /= weights.sum(dim=1, keepdim=True)
self.logits = nn.Parameter(proj_convex_to_real(weights))
@property
def feature_to_scope(self) -> np.ndarray:
return self.inputs.feature_to_scope
@property
def log_weights(self) -> Tensor:
"""Log-normalized weights via log_softmax."""
return torch.nn.functional.log_softmax(self.logits, dim=1)
@property
def weights(self) -> Tensor:
"""Normalized weights via softmax."""
return torch.nn.functional.softmax(self.logits, dim=1)
@cached
def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:
# Input shape: (batch, features, in_channels, reps)
ll = self.inputs.log_likelihood(data, cache=cache)
# Expand for out_channels: (batch, features, in_channels, 1, reps)
ll = ll.unsqueeze(3)
# Weights: (1, features, in_channels, out_channels, reps)
log_w = self.log_weights.unsqueeze(0)
# Weighted sum via logsumexp over in_channels (dim=2)
result = torch.logsumexp(ll + log_w, dim=2)
# Add noise during training only
if self.training:
noise = torch.randn_like(result) * self.noise_std
result = result + noise
return result # Shape: (batch, features, out_channels, reps)
def sample(
self,
num_samples: int | None = None,
data: Tensor | None = None,
is_mpe: bool = False,
cache: Cache | None = None,
sampling_ctx: SamplingContext | None = None,
) -> Tensor:
data = self._prepare_sample_data(num_samples, data)
sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)
logits = self.logits[..., 0]
logits = logits.unsqueeze(0).expand(data.shape[0], -1, -1, -1)
# Gather logits for selected out_channels
idxs = sampling_ctx.channel_index.unsqueeze(-1).unsqueeze(-1)
idxs = idxs.expand(-1, -1, logits.shape[2], -1)
logits = logits.gather(dim=3, index=idxs).squeeze(3)
# Select input channels: either argmax (MPE) or sample
if is_mpe:
new_channels = logits.argmax(dim=-1)
else:
new_channels = torch.distributions.Categorical(logits=logits).sample()
sampling_ctx.channel_index = new_channels
self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=sampling_ctx)
return data
def marginalize(self, marg_rvs: list[int], prune: bool = True, cache=None):
mutual = set(self.scope.query) & set(marg_rvs)
if len(mutual) == len(self.scope.query):
return None
marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
if marg_input is None:
return None
return NoisySum(
inputs=marg_input,
out_channels=self.out_shape.channels,
noise_std=self.noise_std,
num_repetitions=self.out_shape.repetitions,
)
[13]:
# Test NoisySum
leaf = NoisyNormal(scope=Scope([0]), out_channels=4)
noisy_sum = NoisySum(inputs=leaf, out_channels=2, noise_std=0.5)
data = torch.randn(10, 1)
# Training mode: noise added
noisy_sum.train()
ll1 = noisy_sum.log_likelihood(data)
ll2 = noisy_sum.log_likelihood(data)
print(f"Training: outputs differ = {not torch.allclose(ll1, ll2)}")
# Eval mode: deterministic
noisy_sum.eval()
ll1 = noisy_sum.log_likelihood(data)
ll2 = noisy_sum.log_likelihood(data)
print(f"Eval: outputs identical = {torch.allclose(ll1, ll2)}")
Training: outputs differ = True
Eval: outputs identical = True
Key points:
Use the @cached decorator to enable caching for sampling and EM
Weights have shape (features, in_channels, out_channels, repetitions)
Use proj_convex_to_real() to convert probabilities to unconstrained logits
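A quick sanity check on the weight shape and normalization (continuing the session above, where noisy_sum is the module from the previous cell):
# softmax over dim=1 (in_channels) makes each mixture a proper convex combination
w = noisy_sum.weights  # (features, in_channels, out_channels, repetitions)
print(torch.allclose(w.sum(dim=1), torch.ones_like(w.sum(dim=1))))  # expected: True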
How Sum Sampling Works:
Sum modules select paths through the DAG—they don’t generate samples.
Receive current sampling_ctx.channel_index from parent
Gather logits for those specific output channels
Sample from Categorical(logits) (or argmax for MPE)
Update sampling_ctx.channel_index with selected child indices
Call self.inputs.sample(...) to continue traversal
The key code pattern:
# Select which input channel to use for each sample
if is_mpe:
new_channel_index = torch.argmax(logits, dim=-1)
else:
new_channel_index = Categorical(logits=logits).sample()
# Update context and delegate to children
sampling_ctx.channel_index = new_channel_index
self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)
4. Implementing a Product Module¶
Product modules compute joint distributions: \(p(x_1, x_2) = p(x_1) \cdot p(x_2)\).
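In log-space the product is a plain sum, \(\log p(x_1, x_2) = \log p(x_1) + \log p(x_2)\), which is why the log_likelihood() below reduces the input log-likelihoods with torch.sum over the feature dimension.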
To implement a product module from scratch, you need:
Concatenation of multiple inputs via Cat (along the feature dimension)
log_likelihood() that sums log-probs across features
sample() that delegates to the input with an expanded context
feature_to_scope that joins all input scopes
Example: NoisyProduct¶
A product module that adds Gaussian noise during training for regularization. This demonstrates how to add training-time behavior to a product.
[14]:
class NoisyProduct(Module):
"""Product module with additive Gaussian noise during training.
Adds noise to log-likelihoods during training for regularization.
Deterministic during evaluation.
"""
def __init__(self, inputs, noise_std: float = 0.1):
super().__init__()
# Handle single module or list of modules
if isinstance(inputs, list):
self.inputs = Cat(inputs, dim=1) if len(inputs) > 1 else inputs[0]
else:
self.inputs = inputs
self.scope = self.inputs.scope
self.noise_std = noise_std
# Shape: product reduces features to 1
in_shape = self.inputs.out_shape
self.in_shape = in_shape
self.out_shape = ModuleShape(1, in_shape.channels, in_shape.repetitions)
@property
def feature_to_scope(self) -> np.ndarray:
# Join all input scopes into a single scope per repetition
out = []
for r in range(self.out_shape.repetitions):
joined = Scope.join_all(self.inputs.feature_to_scope[:, r])
out.append(np.array([[joined]]))
return np.concatenate(out, axis=1)
@cached
def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:
# Get input log-likelihoods: (batch, features, channels, reps)
ll = self.inputs.log_likelihood(data, cache=cache)
# Product = sum in log-space, reduce over features (dim=1)
result = torch.sum(ll, dim=1, keepdim=True)
# Add noise during training only
if self.training:
noise = torch.randn_like(result) * self.noise_std
result = result + noise
return result # Shape: (batch, 1, channels, reps)
def sample(
self,
num_samples: int | None = None,
data: Tensor | None = None,
is_mpe: bool = False,
cache: Cache | None = None,
sampling_ctx: SamplingContext | None = None,
) -> Tensor:
data = self._prepare_sample_data(num_samples, data)
sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)
# Expand context to match input feature count
in_features = self.inputs.out_shape.features
channel_index = sampling_ctx.channel_index.expand(-1, in_features)
mask = sampling_ctx.mask.expand(-1, in_features)
sampling_ctx.update(channel_index=channel_index, mask=mask)
self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=sampling_ctx)
return data
def marginalize(self, marg_rvs: list[int], prune: bool = True, cache=None):
mutual = set(self.scope.query) & set(marg_rvs)
if len(mutual) == len(self.scope.query):
return None
marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
if marg_input is None:
return None
if prune and marg_input.out_shape.features == 1:
return marg_input
return NoisyProduct(inputs=marg_input, noise_std=self.noise_std)
[15]:
# Test NoisyProduct
from spflow.modules.leaves import Normal
leaf1 = Normal(scope=Scope([0]), out_channels=2)
leaf2 = Normal(scope=Scope([1]), out_channels=2)
noisy_prod = NoisyProduct(inputs=[leaf1, leaf2], noise_std=0.5)
# Verify feature_to_scope matches expected joined scope
print(f"Product scope: {noisy_prod.scope}")
assert len(noisy_prod.scope) == 2
data = torch.randn(5, 2)
# Training mode: noise is added (outputs differ each call)
noisy_prod.train()
ll1 = noisy_prod.log_likelihood(data)
ll2 = noisy_prod.log_likelihood(data)
print(f"Training: outputs differ = {not torch.allclose(ll1, ll2)}")
# Eval mode: deterministic (outputs identical)
noisy_prod.eval()
ll1 = noisy_prod.log_likelihood(data)
ll2 = noisy_prod.log_likelihood(data)
print(f"Eval: outputs identical = {torch.allclose(ll1, ll2)}")
# Verify sampling works and produces correct shape
samples = noisy_prod.sample(num_samples=100)
print(f"Samples shape: {samples.shape}")
assert samples.shape == (100, 2)
Product scope: (0, 1)
Training: outputs differ = True
Eval: outputs identical = True
Samples shape: torch.Size([100, 2])
Key points:
Products sum log-likelihoods across features (dim=1)
Cat(inputs, dim=1) concatenates inputs along the feature dimension
Use self.training to differentiate train/eval behavior
How Product Sampling Works:
Products represent factorization: \(p(X_1, X_2) = p(X_1) \cdot p(X_2)\). They expand the sampling context but don’t select paths.
Inputs have disjoint scopes (different random variables)
Expand channel_index from (batch, 1) to (batch, num_input_features)
Expand mask similarly
Pass expanded context to children—no selection happens
The key code pattern:
# Expand context to match number of input features
channel_index = sampling_ctx.channel_index.expand(-1, self.inputs.out_shape.features)
mask = sampling_ctx.mask.expand(-1, self.inputs.out_shape.features)
sampling_ctx.update(channel_index=channel_index, mask=mask)
# Delegate to children
self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)
Products have no learnable parameters—they are purely structural.
5. Implementing a Split Module¶
Split modules partition an input module into multiple groups. They provide different views of the same input.
Example: RandomSplit¶
A RandomSplit assigns features to groups randomly (fixed at initialization).
[16]:
from spflow.modules.ops.split import Split
from spflow.modules.module import Module
class RandomSplit(Split):
"""Split features into groups via random assignment."""
def __init__(self, inputs: Module, num_splits: int = 2, seed: int = 42):
super().__init__(inputs=inputs, dim=1, num_splits=num_splits)
# Randomly assign each feature to a split
gen = torch.Generator().manual_seed(seed)
num_features = inputs.out_shape.features
assignments = torch.randint(0, num_splits, (num_features,), generator=gen)
# Create boolean masks for each split
self.split_masks = [assignments == i for i in range(num_splits)]
# Store assignments for merge_split_indices
self.register_buffer("_assignments", assignments)
@property
def feature_to_scope(self):
scopes = self.inputs.feature_to_scope
return [
[scopes[j] for j in range(len(scopes)) if self.split_masks[i][j]]
for i in range(self.num_splits)
]
@cached
def log_likelihood(self, data, cache=None):
lls = self.inputs.log_likelihood(data, cache=cache)
return [lls[:, mask, ...] for mask in self.split_masks]
def merge_split_indices(self, *split_indices: Tensor) -> Tensor:
batch_size = split_indices[0].shape[0]
num_features = self.inputs.out_shape.features
# Create output tensor
result = torch.zeros(batch_size, num_features, dtype=split_indices[0].dtype, device=split_indices[0].device)
# Track position within each split
split_positions = [0] * self.num_splits
# Scatter indices back to original positions
for feature_idx in range(num_features):
split_idx = self._assignments[feature_idx].item()
pos = split_positions[split_idx]
result[:, feature_idx] = split_indices[split_idx][:, pos]
split_positions[split_idx] += 1
return result
[17]:
# Test RandomSplit
leaf = Normal(scope=Scope(list(range(6))), out_channels=2)
split = RandomSplit(inputs=leaf, num_splits=2, seed=123)
data = torch.randn(5, 6)
lls = split.log_likelihood(data)
for i, ll in enumerate(lls):
print(f"Split {i} shape: {ll.shape}")
Split 0 shape: torch.Size([5, 5, 2, 1])
Split 1 shape: torch.Size([5, 1, 2, 1])
Key points:
Split modules return a list of tensors (one per split)
self.inputs is a single module in Split (unlike Sum/Product, which can wrap multiple)
feature_to_scope must return a list of scope lists, one per split
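A quick round-trip check of merge_split_indices (continuing the session above; the per-split index tensors are arbitrary placeholders):
# Build one index tensor per split and merge them back into the original feature order
idx0 = torch.zeros(5, int(split.split_masks[0].sum()), dtype=torch.long)
idx1 = torch.ones(5, int(split.split_masks[1].sum()), dtype=torch.long)
merged = split.merge_split_indices(idx0, idx1)
print(merged.shape)  # expected: torch.Size([5, 6]), one index per original feature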
6. Testing Your Module¶
SPFlow provides test utilities in tests/utils/. Here’s a minimal test pattern:
[18]:
def test_noisy_normal_leaf():
"""Test basic functionality of NoisyNormal leaf."""
leaf = NoisyNormal(scope=Scope([0]), out_channels=4, num_repetitions=1, noise_std=0.1)
# Check shapes
assert leaf.out_shape.features == 1
assert leaf.out_shape.channels == 4
assert leaf.out_shape.repetitions == 1
# Check log-likelihood
data = torch.randn(10, 1)
leaf.eval() # Deterministic mode for testing
ll = leaf.log_likelihood(data)
assert ll.shape == (10, 1, 4, 1)
assert torch.isfinite(ll).all()
# Check sampling
samples = leaf.sample(num_samples=100)
assert samples.shape == (100, 1)
test_noisy_normal_leaf()
print("All tests passed!")
All tests passed!
Contributing to SPFlow: If you want to contribute your module to the SPFlow repository, unit tests are required. See tests/ for examples and use pytest with parametrization for thorough coverage.
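A sketch of what such a parametrized test might look like for the NoisyNormal leaf defined above (the parameter values are illustrative only):
import pytest

@pytest.mark.parametrize("out_channels", [1, 4])
@pytest.mark.parametrize("num_repetitions", [1, 3])
def test_noisy_normal_shapes(out_channels, num_repetitions):
    leaf = NoisyNormal(scope=Scope([0]), out_channels=out_channels, num_repetitions=num_repetitions)
    leaf.eval()  # deterministic log-likelihoods for testing
    ll = leaf.log_likelihood(torch.randn(8, 1))
    assert ll.shape == (8, 1, out_channels, num_repetitions)
    assert torch.isfinite(ll).all()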
7. Reference Implementations¶
For complete examples, see:
| Module Type | Reference File |
|---|---|
| Leaf | |
| Sum | |
| Product | |
| Split | |