Developer Guide

This guide shows how to extend SPFlow by implementing custom modules. By the end, you’ll understand how to create:

  • Leaf modules — distributions at the input layer

  • Sum modules — weighted mixtures of child distributions

  • Product modules — factorizations via conditional independence

  • Split modules — utilities for partitioning inputs

Prerequisites: Familiarity with PyTorch nn.Module and basic probability theory.

1. Core Concepts

1.1 Module Hierarchy

All SPFlow modules inherit from Module, which extends torch.nn.Module:

nn.Module
    └── Module (abstract base)
            ├── LeafModule (distributions)
            ├── Sum (weighted mixtures)
            ├── Product (factorization)
            └── Split (partitioning)

1.2 Shape System

SPFlow uses a 3-tuple ModuleShape(features, channels, repetitions) to describe tensor dimensions:

Dimension      Meaning
-----------    -------------------------------------------------------
Features       Number of scope partitions (random variable groupings)
Channels       Parallel distributions per feature
Repetitions    Independent copies of the structure

All intermediate tensors have shape: (batch, features, channels, repetitions).
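
For orientation, here is a minimal sketch of how these pieces fit together (the ModuleShape import path and field names match their use later in this guide):

import torch
from spflow.modules.module_shape import ModuleShape

shape = ModuleShape(4, 8, 2)   # 4 features, 8 channels per feature, 2 repetitions
batch_size = 32

# A log-likelihood tensor produced by a module with this shape
ll = torch.zeros(batch_size, shape.features, shape.channels, shape.repetitions)
print(ll.shape)  # torch.Size([32, 4, 8, 2])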

1.3 Required Interface

Leaf Modules                         Intermediate Modules
-----------------------------------  ----------------------------------
params()                             log_likelihood(data, cache)
_torch_distribution_class            sample(...)
_compute_parameter_estimates(...)    marginalize(...)
_set_mle_parameters(...)             expectation_maximization(...)
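
Expressed as a code skeleton, the same interface looks roughly as follows (method names are taken from the table above; the exact signature of expectation_maximization and other base-class hooks may differ):

from spflow.modules.leaves.leaf import LeafModule
from spflow.modules.module import Module


class MyLeaf(LeafModule):
    def params(self): ...                          # dict of distribution parameters
    @property
    def _torch_distribution_class(self): ...       # e.g. torch.distributions.Normal
    def _compute_parameter_estimates(self, data, weights, bias_correction): ...
    def _set_mle_parameters(self, params_dict): ...


class MyIntermediate(Module):
    def log_likelihood(self, data, cache=None): ...
    def sample(self, num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None): ...
    def marginalize(self, marg_rvs, prune=True, cache=None): ...
    def expectation_maximization(self, data, cache=None): ...   # signature assumed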

1.4 Sampling Architecture: Top-Down Index Propagation

SPFlow uses ancestral sampling with a unique top-down index propagation strategy. Understanding this is crucial for implementing custom modules correctly.

Key insight: Internal nodes (Sum, Product) don’t generate samples—they only update routing indices. Only leaf nodes actually sample from distributions and write values to the output tensor.

The sampling process works as follows:

  1. A data tensor filled with NaN is passed through the entire circuit

  2. A SamplingContext tracks which path to follow through the DAG

  3. Sum nodes select which child channel to sample via their weights

  4. Product nodes expand the context to cover all input features

  5. Leaf nodes generate samples and write them in-place to data

Sampling Flow Example:

Root (Sum)
    ├── samples from Categorical(weights) to pick child index
    └── updates sampling_ctx.channel_index with selected child
           ↓
Product
    ├── expands channel_index from (batch, 1) to (batch, num_features)
    └── passes expanded context to child
           ↓
Leaf (Normal)
    ├── uses channel_index to select which channel's parameters
    ├── uses repetition_idx to select which repetition's parameters
    └── writes samples to data[:, self.scope.query] in-place

This design is efficient (no intermediate tensor allocation) and correct (consistent paths through the circuit).

1.5 The SamplingContext Class

The SamplingContext class manages routing state during sampling. It contains:

Field            Shape              Purpose
---------------  -----------------  ----------------------------------------------------
channel_index    (batch, features)  Which output channel to use at each position
mask             (batch, features)  Boolean mask of which positions need sampling
repetition_idx   (batch,)           Which repetition to use (for multi-repetition circuits)

Why channel_index?

  • Sum modules have multiple output channels (mixture components)

  • During sampling, we must pick exactly one path through the DAG

  • Parent Sum nodes set channel_index[sample_i, feature_j] = index of selected child

  • Children use this to gather the correct logits/parameters

Why repetition_idx?

  • Circuits with num_repetitions > 1 have parallel independent copies

  • RepetitionMixingLayer selects which repetition to use per-sample

  • Leaves use it to index their 3D parameter tensors (features, channels, repetitions)

When implementing a custom module’s sample() method, you typically follow these steps (sketched in code after the list):

  1. Initialize context: sampling_ctx = init_default_sampling_context(sampling_ctx, batch_size, device)

  2. Use current indices to select parameters/weights

  3. Update channel_index and/or mask for children

  4. Call self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)
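
A compressed sketch of that pattern (illustrative only; a real Sum or Product fills in steps 2 and 3 as shown in the sections below):

def sample(self, num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None):
    data = self._prepare_sample_data(num_samples, data)
    # 1. Initialize the routing context
    sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)

    # 2./3. Use the current indices to select parameters/weights and update
    #       channel_index and/or mask for the children (module-specific)
    ...

    # 4. Delegate to the input module; samples are written at the leaves
    self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=sampling_ctx)
    return data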

2. Implementing a Leaf Module

Leaf modules wrap probability distributions. The base class LeafModule handles most functionality—you only need to define:

  1. Distribution parameters as nn.Parameter

  2. A params() method returning a dict of parameters

  3. The PyTorch distribution class to use

  4. MLE estimation logic (for parameter learning)

Example: NoisyNormal

A Normal distribution that adds noise to log-likelihoods during training. This demonstrates extending a standard distribution with training-time regularization.

[10]:
import torch
from torch import Tensor, nn
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import init_parameter


class NoisyNormal(LeafModule):
    """Normal distribution with additive noise during training.

    Adds Gaussian noise to log-likelihoods during training for regularization.
    Deterministic during evaluation.
    """

    def __init__(self, scope, out_channels=None, num_repetitions=1,
                 parameter_fn=None, validate_args=True, loc=None, scale=None,
                 noise_std: float = 0.1):
        super().__init__(
            scope=scope, out_channels=out_channels,
            num_repetitions=num_repetitions, params=[loc, scale],
            parameter_fn=parameter_fn, validate_args=validate_args,
        )
        # Initialize loc and scale (scale stored in log-space for positivity)
        loc = init_parameter(loc, self._event_shape, init=torch.zeros)
        scale = init_parameter(scale, self._event_shape, init=torch.ones)
        self.loc = nn.Parameter(loc)
        self.log_scale = nn.Parameter(torch.log(scale))
        self.noise_std = noise_std

    @property
    def scale(self):
        return torch.exp(self.log_scale)

    @property
    def _supported_value(self):
        return 0.0  # Mean is always in support

    @property
    def _torch_distribution_class(self):
        return torch.distributions.Normal

    def params(self):
        return {"loc": self.loc, "scale": self.scale}

    def log_likelihood(self, data, cache=None):
        # Overwrite LeafModule implementation to add noise during training

        # Call LeafModule implementation
        ll = super().log_likelihood(data, cache=cache)

        # Add noise during training only
        if self.training:
            noise = torch.randn_like(ll) * self.noise_std
            ll = ll + noise

        return ll

    def _compute_parameter_estimates(self, data, weights, bias_correction):
        # MLE for Normal: loc = weighted mean, scale = weighted std
        n = weights.sum(dim=0)
        mean = (weights * data).sum(dim=0) / n
        var = (weights * (data - mean) ** 2).sum(dim=0) / n
        return {"loc": mean, "scale": torch.sqrt(var + 1e-8)}

    def _set_mle_parameters(self, params_dict):
        self.loc.data = params_dict["loc"]
        self.log_scale.data = torch.log(params_dict["scale"])
[11]:
# Quick test
from spflow.meta import Scope

leaf = NoisyNormal(scope=Scope([0]), out_channels=3, noise_std=0.5)

# Training mode: noise added
leaf.train()
data = torch.randn(5, 1)
ll1 = leaf.log_likelihood(data)
ll2 = leaf.log_likelihood(data)
print(f"Training: outputs differ = {not torch.allclose(ll1, ll2)}")

# Eval mode: deterministic
leaf.eval()
ll1 = leaf.log_likelihood(data)
ll2 = leaf.log_likelihood(data)
print(f"Eval: outputs identical = {torch.allclose(ll1, ll2)}")
Training: outputs differ = True
Eval: outputs identical = True

Key points:

  • Store constrained parameters in transformed space (e.g., log_scale for positive scale)

  • init_parameter() handles shape inference from out_channels

  • Override log_likelihood() to add custom behavior while calling super()

How Leaf Sampling Works:

Leaves are the only modules that actually generate samples. The sampling flow is (sketched in code after the list):

  1. Receive data tensor with NaN at positions to sample

  2. Use sampling_ctx.channel_index to select which channel’s parameters

  3. Use sampling_ctx.repetition_idx to select which repetition’s parameters

  4. Sample from the distribution (or take mode for MPE)

  5. Write samples in-place to data[:, self.scope.query]
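
A conceptual sketch of these steps (simplified: _select_params is a hypothetical helper standing in for the gather logic that LeafModule performs on its 3D parameter tensors):

def sample(self, num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None):
    data = self._prepare_sample_data(num_samples, data)   # NaN at positions to sample
    sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)

    # Pick the parameters of the routed channel and repetition for each sample
    # (_select_params is a hypothetical helper; the base class uses gather here)
    loc, scale = self._select_params(sampling_ctx.channel_index, sampling_ctx.repetition_idx)

    # Sample from the distribution (for a Normal, the mode used in MPE equals the mean)
    dist = torch.distributions.Normal(loc, scale)
    samples = dist.mean if is_mpe else dist.sample()

    # Write samples in-place at this leaf's scope
    data[:, self.scope.query] = samples
    return data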

3. Implementing a Sum Module

Sum modules compute weighted mixtures: \(p(x) = \sum_i w_i \cdot p_i(x)\).

We now implement a custom sum module from scratch to demonstrate what is necessary to extend SPFlow with new sum-like operations. Our NoisySum adds noise during training for regularization:

\[\log p(x) = \log \sum_i w_i \cdot p_i(x) + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2)\]

To implement any sum module, you need:

  1. Weight parameters (stored as logits for unconstrained optimization)

  2. log_likelihood() using logsumexp for numerical stability

  3. sample() that selects input channels based on weights

  4. feature_to_scope property mapping features to scopes

Example: NoisySum

A sum module that adds Gaussian noise during training.

[12]:
import numpy as np
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.cat import Cat
from spflow.utils.cache import Cache, cached
from spflow.utils.projections import proj_convex_to_real
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context


class NoisySum(Module):
    """Sum module with additive noise during training.

    Adds Gaussian noise to log-likelihoods during training for regularization.
    Deterministic during evaluation.
    """

    def __init__(self, inputs, out_channels: int, noise_std: float = 0.1, num_repetitions: int = 1):
        super().__init__()

        # Handle single module or list of modules
        if isinstance(inputs, list):
            self.inputs = Cat(inputs, dim=2) if len(inputs) > 1 else inputs[0]
        else:
            self.inputs = inputs

        self.scope = self.inputs.scope
        self.noise_std = noise_std

        # Shape computation
        in_shape = self.inputs.out_shape
        self.in_shape = in_shape
        self.out_shape = ModuleShape(in_shape.features, out_channels, num_repetitions)

        # Weight shape: (features, in_channels, out_channels, repetitions)
        self._weights_shape = (
            in_shape.features, in_shape.channels, out_channels, num_repetitions
        )

        # Initialize weights randomly (store as logits for unconstrained optimization)
        weights = torch.rand(self._weights_shape) + 1e-8
        weights /= weights.sum(dim=1, keepdim=True)
        self.logits = nn.Parameter(proj_convex_to_real(weights))

    @property
    def feature_to_scope(self) -> np.ndarray:
        return self.inputs.feature_to_scope

    @property
    def log_weights(self) -> Tensor:
        """Log-normalized weights via log_softmax."""
        return torch.nn.functional.log_softmax(self.logits, dim=1)

    @property
    def weights(self) -> Tensor:
        """Normalized weights via softmax."""
        return torch.nn.functional.softmax(self.logits, dim=1)

    @cached
    def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:
        # Input shape: (batch, features, in_channels, reps)
        ll = self.inputs.log_likelihood(data, cache=cache)

        # Expand for out_channels: (batch, features, in_channels, 1, reps)
        ll = ll.unsqueeze(3)

        # Weights: (1, features, in_channels, out_channels, reps)
        log_w = self.log_weights.unsqueeze(0)

        # Weighted sum via logsumexp over in_channels (dim=2)
        result = torch.logsumexp(ll + log_w, dim=2)

        # Add noise during training only
        if self.training:
            noise = torch.randn_like(result) * self.noise_std
            result = result + noise

        return result  # Shape: (batch, features, out_channels, reps)

    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        sampling_ctx: SamplingContext | None = None,
    ) -> Tensor:
        data = self._prepare_sample_data(num_samples, data)
        sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)

        # Select repetition 0's logits (kept simple here; a multi-repetition
        # module would typically index with sampling_ctx.repetition_idx)
        logits = self.logits[..., 0]
        logits = logits.unsqueeze(0).expand(data.shape[0], -1, -1, -1)

        # Gather logits for selected out_channels
        idxs = sampling_ctx.channel_index.unsqueeze(-1).unsqueeze(-1)
        idxs = idxs.expand(-1, -1, logits.shape[2], -1)
        logits = logits.gather(dim=3, index=idxs).squeeze(3)

        # Select input channels: either argmax (MPE) or sample
        if is_mpe:
            new_channels = logits.argmax(dim=-1)
        else:
            new_channels = torch.distributions.Categorical(logits=logits).sample()

        sampling_ctx.channel_index = new_channels
        self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=sampling_ctx)
        return data

    def marginalize(self, marg_rvs: list[int], prune: bool = True, cache=None):
        mutual = set(self.scope.query) & set(marg_rvs)
        if len(mutual) == len(self.scope.query):
            return None

        marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
        if marg_input is None:
            return None

        return NoisySum(
            inputs=marg_input,
            out_channels=self.out_shape.channels,
            noise_std=self.noise_std,
            num_repetitions=self.out_shape.repetitions,
        )
[13]:
# Test NoisySum

leaf = NoisyNormal(scope=Scope([0]), out_channels=4)
noisy_sum = NoisySum(inputs=leaf, out_channels=2, noise_std=0.5)

data = torch.randn(10, 1)

# Training mode: noise added
noisy_sum.train()
ll1 = noisy_sum.log_likelihood(data)
ll2 = noisy_sum.log_likelihood(data)
print(f"Training: outputs differ = {not torch.allclose(ll1, ll2)}")

# Eval mode: deterministic
noisy_sum.eval()
ll1 = noisy_sum.log_likelihood(data)
ll2 = noisy_sum.log_likelihood(data)
print(f"Eval: outputs identical = {torch.allclose(ll1, ll2)}")


Training: outputs differ = True
Eval: outputs identical = True

Key points:

  • Use the @cached decorator to enable caching for sampling and EM

  • Weights have shape (features, in_channels, out_channels, repetitions)

  • Use proj_convex_to_real() to convert probabilities to unconstrained logits

How Sum Sampling Works:

Sum modules select paths through the DAG—they don’t generate samples.

  1. Receive current sampling_ctx.channel_index from parent

  2. Gather logits for those specific output channels

  3. Sample from Categorical(logits) (or argmax for MPE)

  4. Update sampling_ctx.channel_index with selected child indices

  5. Call self.inputs.sample(...) to continue traversal

The key code pattern:

# Select which input channel to use for each sample
if is_mpe:
    new_channel_index = torch.argmax(logits, dim=-1)
else:
    new_channel_index = Categorical(logits=logits).sample()

# Update context and delegate to children
sampling_ctx.channel_index = new_channel_index
self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)

4. Implementing a Product Module

Product modules compute joint distributions: \(p(x_1, x_2) = p(x_1) \cdot p(x_2)\).

To implement a product module from scratch, you need:

  1. Concatenate multiple inputs via Cat (along feature dimension)

  2. log_likelihood() that sums log-probs across features

  3. sample() that delegates to the input with expanded context

  4. feature_to_scope that joins all input scopes

Example: NoisyProduct

A product module that adds Gaussian noise during training for regularization. This demonstrates how to add training-time behavior to a product.

\[\log p(x) = \sum_j \log p_j(x_j) + \epsilon, \quad \epsilon \sim \mathcal{N}(0, \sigma^2)\]
[14]:
class NoisyProduct(Module):
    """Product module with additive Gaussian noise during training.

    Adds noise to log-likelihoods during training for regularization.
    Deterministic during evaluation.
    """

    def __init__(self, inputs, noise_std: float = 0.1):
        super().__init__()

        # Handle single module or list of modules
        if isinstance(inputs, list):
            self.inputs = Cat(inputs, dim=1) if len(inputs) > 1 else inputs[0]
        else:
            self.inputs = inputs

        self.scope = self.inputs.scope
        self.noise_std = noise_std

        # Shape: product reduces features to 1
        in_shape = self.inputs.out_shape
        self.in_shape = in_shape
        self.out_shape = ModuleShape(1, in_shape.channels, in_shape.repetitions)

    @property
    def feature_to_scope(self) -> np.ndarray:
        # Join all input scopes into a single scope per repetition
        out = []
        for r in range(self.out_shape.repetitions):
            joined = Scope.join_all(self.inputs.feature_to_scope[:, r])
            out.append(np.array([[joined]]))
        return np.concatenate(out, axis=1)

    @cached
    def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:
        # Get input log-likelihoods: (batch, features, channels, reps)
        ll = self.inputs.log_likelihood(data, cache=cache)

        # Product = sum in log-space, reduce over features (dim=1)
        result = torch.sum(ll, dim=1, keepdim=True)

        # Add noise during training only
        if self.training:
            noise = torch.randn_like(result) * self.noise_std
            result = result + noise

        return result  # Shape: (batch, 1, channels, reps)

    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        sampling_ctx: SamplingContext | None = None,
    ) -> Tensor:
        data = self._prepare_sample_data(num_samples, data)
        sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)

        # Expand context to match input feature count
        in_features = self.inputs.out_shape.features
        channel_index = sampling_ctx.channel_index.expand(-1, in_features)
        mask = sampling_ctx.mask.expand(-1, in_features)
        sampling_ctx.update(channel_index=channel_index, mask=mask)

        self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=sampling_ctx)
        return data

    def marginalize(self, marg_rvs: list[int], prune: bool = True, cache=None):
        mutual = set(self.scope.query) & set(marg_rvs)
        if len(mutual) == len(self.scope.query):
            return None

        marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
        if marg_input is None:
            return None

        if prune and marg_input.out_shape.features == 1:
            return marg_input

        return NoisyProduct(inputs=marg_input, noise_std=self.noise_std)
[15]:
# Test NoisyProduct
from spflow.modules.leaves import Normal

leaf1 = Normal(scope=Scope([0]), out_channels=2)
leaf2 = Normal(scope=Scope([1]), out_channels=2)
noisy_prod = NoisyProduct(inputs=[leaf1, leaf2], noise_std=0.5)

# Verify feature_to_scope matches expected joined scope
print(f"Product scope: {noisy_prod.scope}")
assert len(noisy_prod.scope) == 2

data = torch.randn(5, 2)

# Training mode: noise is added (outputs differ each call)
noisy_prod.train()
ll1 = noisy_prod.log_likelihood(data)
ll2 = noisy_prod.log_likelihood(data)
print(f"Training: outputs differ = {not torch.allclose(ll1, ll2)}")

# Eval mode: deterministic (outputs identical)
noisy_prod.eval()
ll1 = noisy_prod.log_likelihood(data)
ll2 = noisy_prod.log_likelihood(data)
print(f"Eval: outputs identical = {torch.allclose(ll1, ll2)}")

# Verify sampling works and produces correct shape
samples = noisy_prod.sample(num_samples=100)
print(f"Samples shape: {samples.shape}")
assert samples.shape == (100, 2)
Product scope: (0, 1)
Training: outputs differ = True
Eval: outputs identical = True
Samples shape: torch.Size([100, 2])

Key points:

  • Products sum log-likelihoods across features (dim=1)

  • Cat(inputs, dim=1) concatenates inputs along the feature dimension

  • Use self.training to differentiate train/eval behavior

How Product Sampling Works:

Products represent factorization: \(p(X_1, X_2) = p(X_1) \cdot p(X_2)\). They expand the sampling context but don’t select paths.

  1. Inputs have disjoint scopes (different random variables)

  2. Expand channel_index from (batch, 1) to (batch, num_input_features)

  3. Expand mask similarly

  4. Pass expanded context to children—no selection happens

The key code pattern:

# Expand context to match number of input features
channel_index = sampling_ctx.channel_index.expand(-1, self.inputs.out_shape.features)
mask = sampling_ctx.mask.expand(-1, self.inputs.out_shape.features)
sampling_ctx.update(channel_index=channel_index, mask=mask)

# Delegate to children
self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)

Products have no learnable parameters—they are purely structural.

5. Implementing a Split Module

Split modules partition an input module into multiple groups. They provide different views of the same input.

Example: RandomSplit

A RandomSplit assigns features to groups randomly (fixed at initialization).

[16]:
from spflow.modules.ops.split import Split
from spflow.modules.module import Module


class RandomSplit(Split):
    """Split features into groups via random assignment."""

    def __init__(self, inputs: Module, num_splits: int = 2, seed: int = 42):
        super().__init__(inputs=inputs, dim=1, num_splits=num_splits)

        # Randomly assign each feature to a split
        gen = torch.Generator().manual_seed(seed)
        num_features = inputs.out_shape.features
        assignments = torch.randint(0, num_splits, (num_features,), generator=gen)

        # Create boolean masks for each split
        self.split_masks = [assignments == i for i in range(num_splits)]

        # Store assignments for merge_split_indices
        self.register_buffer("_assignments", assignments)

    @property
    def feature_to_scope(self):
        scopes = self.inputs.feature_to_scope
        return [
            [scopes[j] for j in range(len(scopes)) if self.split_masks[i][j]]
            for i in range(self.num_splits)
        ]

    @cached
    def log_likelihood(self, data, cache=None):
        lls = self.inputs.log_likelihood(data, cache=cache)
        return [lls[:, mask, ...] for mask in self.split_masks]


    def merge_split_indices(self, *split_indices: Tensor) -> Tensor:
        batch_size = split_indices[0].shape[0]
        num_features = self.inputs.out_shape.features
        # Create output tensor
        result = torch.zeros(batch_size, num_features, dtype=split_indices[0].dtype, device=split_indices[0].device)
        # Track position within each split
        split_positions = [0] * self.num_splits
        # Scatter indices back to original positions
        for feature_idx in range(num_features):
            split_idx = self._assignments[feature_idx].item()
            pos = split_positions[split_idx]
            result[:, feature_idx] = split_indices[split_idx][:, pos]
            split_positions[split_idx] += 1
        return result
[17]:
# Test RandomSplit
leaf = Normal(scope=Scope(list(range(6))), out_channels=2)
split = RandomSplit(inputs=leaf, num_splits=2, seed=123)

data = torch.randn(5, 6)
lls = split.log_likelihood(data)

for i, ll in enumerate(lls):
    print(f"Split {i} shape: {ll.shape}")


Split 0 shape: torch.Size([5, 5, 2, 1])
Split 1 shape: torch.Size([5, 1, 2, 1])

Key points:

  • Split modules return a list of tensors (one per split)

  • self.inputs is a single module in Split (unlike Sum/Product which can wrap multiple)

  • feature_to_scope must return a list of scope lists, one per split
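
To see merge_split_indices in action, here is a small follow-up check using the split instance from the test above (the index values are arbitrary):

# One channel-index tensor per split, each shaped (batch, features_in_that_split)
idx0 = torch.zeros(5, int(split.split_masks[0].sum()), dtype=torch.long)
idx1 = torch.ones(5, int(split.split_masks[1].sum()), dtype=torch.long)

merged = split.merge_split_indices(idx0, idx1)
print(merged.shape)   # torch.Size([5, 6])
# merged[:, f] is 0 where feature f was assigned to split 0 and 1 where it went to split 1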

6. Testing Your Module

SPFlow provides test utilities in tests/utils/. Here’s a minimal test pattern:

[18]:
def test_noisy_normal_leaf():
    """Test basic functionality of NoisyNormal leaf."""
    leaf = NoisyNormal(scope=Scope([0]), out_channels=4, num_repetitions=1, noise_std=0.1)

    # Check shapes
    assert leaf.out_shape.features == 1
    assert leaf.out_shape.channels == 4
    assert leaf.out_shape.repetitions == 1

    # Check log-likelihood
    data = torch.randn(10, 1)
    leaf.eval()  # Deterministic mode for testing
    ll = leaf.log_likelihood(data)
    assert ll.shape == (10, 1, 4, 1)
    assert torch.isfinite(ll).all()

    # Check sampling
    samples = leaf.sample(num_samples=100)
    assert samples.shape == (100, 1)

test_noisy_normal_leaf()
print("All tests passed!")
All tests passed!

Contributing to SPFlow: If you want to contribute your module to the SPFlow repository, unit tests are required. See tests/ for examples and use pytest with parametrization for thorough coverage.
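
A sketch of what such a parametrized test might look like (assuming pytest is available; expected shapes follow the NoisyNormal test above):

import pytest

@pytest.mark.parametrize("out_channels", [1, 2, 4])
@pytest.mark.parametrize("num_repetitions", [1, 3])
def test_noisy_normal_ll_shape(out_channels, num_repetitions):
    leaf = NoisyNormal(scope=Scope([0]), out_channels=out_channels, num_repetitions=num_repetitions)
    leaf.eval()  # deterministic mode for testing
    ll = leaf.log_likelihood(torch.randn(10, 1))
    assert ll.shape == (10, 1, out_channels, num_repetitions)
    assert torch.isfinite(ll).all()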

7. Reference Implementations

For complete examples, see:

Module Type   Reference File
-----------   -----------------------------------------------
Leaf          spflow/modules/leaves/normal.py (~115 lines)
Sum           spflow/modules/sums/sum.py
Product       spflow/modules/products/product.py (~210 lines)
Split         spflow/modules/ops/split.py