{ "cells": [ { "cell_type": "markdown", "id": "d9869dc1", "metadata": {}, "source": [ "# Developer Guide\n", "\n", "This guide shows how to extend SPFlow by implementing custom modules. By the end, you'll understand how to create:\n", "\n", "- **Leaf modules** — distributions at the input layer\n", "- **Sum modules** — weighted mixtures of child distributions\n", "- **Product modules** — factorizations via conditional independence\n", "- **Split modules** — utilities for partitioning inputs\n", "\n", "> **Prerequisites:** Familiarity with PyTorch `nn.Module` and basic probability theory.\n", "\n", "## 1. Core Concepts\n", "\n", "### 1.1 Module Hierarchy\n", "\n", "All SPFlow modules inherit from `Module`, which extends `torch.nn.Module`:\n", "\n", "```\n", "nn.Module\n", " └── Module (abstract base)\n", " ├── LeafModule (distributions)\n", " ├── Sum (weighted mixtures)\n", " ├── Product (factorization)\n", " └── Split (partitioning)\n", "```\n", "\n", "### 1.2 Shape System\n", "\n", "SPFlow uses a 3-tuple `ModuleShape(features, channels, repetitions)` to describe tensor dimensions:\n", "\n", "| Dimension | Meaning |\n", "|-----------|--------|\n", "| **Features** | Number of scope partitions (random variable groupings) |\n", "| **Channels** | Parallel distributions per feature |\n", "| **Repetitions** | Independent copies of the structure |\n", "\n", "All intermediate tensors have shape: `(batch, features, channels, repetitions)`.\n", "\n", "### 1.3 Required Interface\n", "\n", "| Leaf Modules | Intermediate Modules |\n", "|-------------|---------------------|\n", "| `params()` | `log_likelihood(data, cache)` |\n", "| `_torch_distribution_class` | `sample(...)` |\n", "| `_compute_parameter_estimates(...)` | `marginalize(...)` |\n", "| `_set_mle_parameters(...)` | `expectation_maximization(...)` |\n", "\n", "### 1.4 Sampling Architecture: Top-Down Index Propagation\n", "\n", "SPFlow uses **ancestral sampling** with a unique top-down index propagation strategy.\n", "Understanding this is crucial for implementing custom modules correctly.\n", "\n", "**Key insight:** Internal nodes (Sum, Product) **don't generate samples**—they only update\n", "routing indices. Only **leaf nodes** actually sample from distributions and write values\n", "to the output tensor.\n", "\n", "The sampling process works as follows:\n", "\n", "1. A `data` tensor filled with `NaN` is passed through the entire circuit\n", "2. A `SamplingContext` tracks which path to follow through the DAG\n", "3. Sum nodes **select which child channel** to sample via their weights\n", "4. Product nodes **expand the context** to cover all input features\n", "5. Leaf nodes **generate samples** and write them in-place to `data`\n", "\n", "```\n", "Sampling Flow Example:\n", "\n", "Root (Sum)\n", " ├── samples from Categorical(weights) to pick child index\n", " └── updates sampling_ctx.channel_index with selected child\n", " ↓\n", "Product\n", " ├── expands channel_index from (batch, 1) to (batch, num_features)\n", " └── passes expanded context to child\n", " ↓\n", "Leaf (Normal)\n", " ├── uses channel_index to select which channel's parameters\n", " ├── uses repetition_idx to select which repetition's parameters\n", " └── writes samples to data[:, self.scope.query] in-place\n", "```\n", "\n", "This design is efficient (no intermediate tensor allocation) and correct (consistent\n", "paths through the circuit).\n", "\n", "### 1.5 The `SamplingContext` Class\n", "\n", "The `SamplingContext` class manages routing state during sampling. 
It contains:\n", "\n", "| Field | Shape | Purpose |\n", "|-------|-------|---------|\n", "| `channel_index` | `(batch, features)` | Which output channel to use at each position |\n", "| `mask` | `(batch, features)` | Boolean mask—which positions need sampling |\n", "| `repetition_idx` | `(batch,)` | Which repetition to use (for multi-repetition circuits) |\n", "\n", "**Why `channel_index`?**\n", "- Sum modules have multiple output channels (mixture components)\n", "- During sampling, we must pick exactly one path through the DAG\n", "- Parent Sum nodes set `channel_index[sample_i, feature_j]` = index of selected child\n", "- Children use this to gather the correct logits/parameters\n", "\n", "**Why `repetition_idx`?**\n", "- Circuits with `num_repetitions > 1` have parallel independent copies\n", "- `RepetitionMixingLayer` selects which repetition to use per-sample\n", "- Leaves use it to index their 3D parameter tensors `(features, channels, repetitions)`\n", "\n", "When implementing a custom module's `sample()` method, you typically:\n", "1. Initialize context: `sampling_ctx = init_default_sampling_context(sampling_ctx, batch_size, device)`\n", "2. Use current indices to select parameters/weights\n", "3. Update `channel_index` and/or `mask` for children\n", "4. Call `self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)`\n", "\n", "## 2. Implementing a Leaf Module\n", "\n", "Leaf modules wrap probability distributions. The base class `LeafModule` handles most functionality—you only need to define:\n", "\n", "1. Distribution parameters as `nn.Parameter`\n", "2. A `params()` method returning a dict of parameters\n", "3. The PyTorch distribution class to use\n", "4. MLE estimation logic (for parameter learning)\n", "\n", "### Example: NoisyNormal\n", "\n", "A Normal distribution that adds noise to log-likelihoods during training.\n", "This demonstrates extending a standard distribution with training-time regularization." 
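, "\n", "\n", "Concretely, during training the leaf perturbs its exact Normal log-density with Gaussian noise,\n", "\n", "$$\\log \\tilde{p}(x) = \\log p(x) + \\epsilon, \\quad \\epsilon \\sim \\mathcal{N}(0, \\sigma^2)$$\n", "\n", "where $\\sigma$ is `noise_std`; in eval mode the noise is skipped and the log-density is exact.\n"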
] }, { "cell_type": "code", "execution_count": 10, "id": "ef85e105", "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torch import Tensor, nn\n", "from spflow.modules.leaves.leaf import LeafModule\n", "from spflow.utils.leaves import init_parameter\n", "\n", "\n", "class NoisyNormal(LeafModule):\n", " \"\"\"Normal distribution with additive noise during training.\n", " \n", " Adds Gaussian noise to log-likelihoods during training for regularization.\n", " Deterministic during evaluation.\n", " \"\"\"\n", "\n", " def __init__(self, scope, out_channels=None, num_repetitions=1,\n", " parameter_fn=None, validate_args=True, loc=None, scale=None,\n", " noise_std: float = 0.1):\n", " super().__init__(\n", " scope=scope, out_channels=out_channels,\n", " num_repetitions=num_repetitions, params=[loc, scale],\n", " parameter_fn=parameter_fn, validate_args=validate_args,\n", " )\n", " # Initialize loc and scale (scale stored in log-space for positivity)\n", " loc = init_parameter(loc, self._event_shape, init=torch.zeros)\n", " scale = init_parameter(scale, self._event_shape, init=torch.ones)\n", " self.loc = nn.Parameter(loc)\n", " self.log_scale = nn.Parameter(torch.log(scale))\n", " self.noise_std = noise_std\n", "\n", " @property\n", " def scale(self):\n", " return torch.exp(self.log_scale)\n", "\n", " @property\n", " def _supported_value(self):\n", " return 0.0 # Mean is always in support\n", "\n", " @property\n", " def _torch_distribution_class(self):\n", " return torch.distributions.Normal\n", "\n", " def params(self):\n", " return {\"loc\": self.loc, \"scale\": self.scale}\n", "\n", " def log_likelihood(self, data, cache=None):\n", " # Overwrite LeafModule implementation to add noise during training\n", "\n", " # Call LeafModule implementation\n", " ll = super().log_likelihood(data, cache=cache)\n", " \n", " # Add noise during training only\n", " if self.training:\n", " noise = torch.randn_like(ll) * self.noise_std\n", " ll = ll + noise\n", " \n", " return ll\n", "\n", " def _compute_parameter_estimates(self, data, weights, bias_correction):\n", " # MLE for Normal: loc = weighted mean, scale = weighted std\n", " n = weights.sum(dim=0)\n", " mean = (weights * data).sum(dim=0) / n\n", " var = (weights * (data - mean) ** 2).sum(dim=0) / n\n", " return {\"loc\": mean, \"scale\": torch.sqrt(var + 1e-8)}\n", "\n", " def _set_mle_parameters(self, params_dict):\n", " self.loc.data = params_dict[\"loc\"]\n", " self.log_scale.data = torch.log(params_dict[\"scale\"])" ] }, { "cell_type": "code", "execution_count": 11, "id": "fa585400", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training: outputs differ = True\n", "Eval: outputs identical = True\n" ] } ], "source": [ "# Quick test\n", "from spflow.meta import Scope\n", "\n", "leaf = NoisyNormal(scope=Scope([0]), out_channels=3, noise_std=0.5)\n", "\n", "# Training mode: noise added\n", "leaf.train()\n", "data = torch.randn(5, 1)\n", "ll1 = leaf.log_likelihood(data)\n", "ll2 = leaf.log_likelihood(data)\n", "print(f\"Training: outputs differ = {not torch.allclose(ll1, ll2)}\")\n", "\n", "# Eval mode: deterministic\n", "leaf.eval()\n", "ll1 = leaf.log_likelihood(data)\n", "ll2 = leaf.log_likelihood(data)\n", "print(f\"Eval: outputs identical = {torch.allclose(ll1, ll2)}\")" ] }, { "cell_type": "markdown", "id": "5c162401", "metadata": {}, "source": [ "**Key points:**\n", "- Store constrained parameters in transformed space (e.g., `log_scale` for positive `scale`)\n", "- `init_parameter()` handles 
shape inference from `out_channels`\n", "- Override `log_likelihood()` to add custom behavior while calling `super()`\n", "\n", "**How Leaf Sampling Works:**\n", "\n", "Leaves are the **only** modules that actually generate samples. The sampling flow is:\n", "\n", "1. Receive `data` tensor with `NaN` at positions to sample\n", "2. Use `sampling_ctx.channel_index` to select which channel's parameters\n", "3. Use `sampling_ctx.repetition_idx` to select which repetition's parameters\n", "4. Sample from the distribution (or take mode for MPE)\n", "5. Write samples **in-place** to `data[:, self.scope.query]`\n", "\n", "\n", "## 3. Implementing a Sum Module\n", "\n", "Sum modules compute weighted mixtures: $p(x) = \\sum_i w_i \\cdot p_i(x)$.\n", "\n", "We now implement a custom sum module from scratch to demonstrate what is necessary\n", "to extend SPFlow with new sum-like operations. Our `NoisySum` adds noise during\n", "training for regularization:\n", "\n", "$$\\log p(x) = \\log \\sum_i w_i \\cdot p_i(x) + \\epsilon, \\quad \\epsilon \\sim \\mathcal{N}(0, \\sigma^2)$$\n", "\n", "To implement any sum module, you need:\n", "1. Weight parameters (stored as logits for unconstrained optimization)\n", "2. `log_likelihood()` using logsumexp for numerical stability \n", "3. `sample()` that selects input channels based on weights\n", "4. `feature_to_scope` property mapping features to scopes\n", "\n", "### Example: NoisySum\n", "\n", "A sum module that adds Gaussian noise during training." ] }, { "cell_type": "code", "execution_count": 12, "id": "98b465d4", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from spflow.modules.module import Module\n", "from spflow.modules.module_shape import ModuleShape\n", "from spflow.modules.ops.cat import Cat\n", "from spflow.utils.cache import Cache, cached\n", "from spflow.utils.projections import proj_convex_to_real\n", "from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context\n", "\n", "\n", "class NoisySum(Module):\n", " \"\"\"Sum module with additive noise during training.\n", " \n", " Adds Gaussian noise to log-likelihoods during training for regularization.\n", " Deterministic during evaluation.\n", " \"\"\"\n", "\n", " def __init__(self, inputs, out_channels: int, noise_std: float = 0.1, num_repetitions: int = 1):\n", " super().__init__()\n", "\n", " # Handle single module or list of modules\n", " if isinstance(inputs, list):\n", " self.inputs = Cat(inputs, dim=2) if len(inputs) > 1 else inputs[0]\n", " else:\n", " self.inputs = inputs\n", "\n", " self.scope = self.inputs.scope\n", " self.noise_std = noise_std\n", "\n", " # Shape computation\n", " in_shape = self.inputs.out_shape\n", " self.in_shape = in_shape\n", " self.out_shape = ModuleShape(in_shape.features, out_channels, num_repetitions)\n", "\n", " # Weight shape: (features, in_channels, out_channels, repetitions)\n", " self._weights_shape = (\n", " in_shape.features, in_shape.channels, out_channels, num_repetitions\n", " )\n", "\n", " # Initialize weights randomly (store as logits for unconstrained optimization)\n", " weights = torch.rand(self._weights_shape) + 1e-8\n", " weights /= weights.sum(dim=1, keepdim=True)\n", " self.logits = nn.Parameter(proj_convex_to_real(weights))\n", "\n", " @property\n", " def feature_to_scope(self) -> np.ndarray:\n", " return self.inputs.feature_to_scope\n", "\n", " @property\n", " def log_weights(self) -> Tensor:\n", " \"\"\"Log-normalized weights via log_softmax.\"\"\"\n", " return 
torch.nn.functional.log_softmax(self.logits, dim=1)\n", "\n", "    @property\n", "    def weights(self) -> Tensor:\n", "        \"\"\"Normalized weights via softmax.\"\"\"\n", "        return torch.nn.functional.softmax(self.logits, dim=1)\n", "\n", "    @cached\n", "    def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:\n", "        # Input shape: (batch, features, in_channels, reps)\n", "        ll = self.inputs.log_likelihood(data, cache=cache)\n", "\n", "        # Expand for out_channels: (batch, features, in_channels, 1, reps)\n", "        ll = ll.unsqueeze(3)\n", "\n", "        # Weights: (1, features, in_channels, out_channels, reps)\n", "        log_w = self.log_weights.unsqueeze(0)\n", "\n", "        # Weighted sum via logsumexp over in_channels (dim=2)\n", "        result = torch.logsumexp(ll + log_w, dim=2)\n", "\n", "        # Add noise during training only\n", "        if self.training:\n", "            noise = torch.randn_like(result) * self.noise_std\n", "            result = result + noise\n", "\n", "        return result  # Shape: (batch, features, out_channels, reps)\n", "\n", "    def sample(\n", "        self,\n", "        num_samples: int | None = None,\n", "        data: Tensor | None = None,\n", "        is_mpe: bool = False,\n", "        cache: Cache | None = None,\n", "        sampling_ctx: SamplingContext | None = None,\n", "    ) -> Tensor:\n", "        data = self._prepare_sample_data(num_samples, data)\n", "        sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)\n", "\n", "        # Simplification: use repetition 0; a full implementation would gather\n", "        # per-sample repetitions via sampling_ctx.repetition_idx.\n", "        logits = self.logits[..., 0]\n", "        logits = logits.unsqueeze(0).expand(data.shape[0], -1, -1, -1)\n", "\n", "        # Gather logits for selected out_channels\n", "        idxs = sampling_ctx.channel_index.unsqueeze(-1).unsqueeze(-1)\n", "        idxs = idxs.expand(-1, -1, logits.shape[2], -1)\n", "        logits = logits.gather(dim=3, index=idxs).squeeze(3)\n", "\n", "        # Select input channels: either argmax (MPE) or sample\n", "        if is_mpe:\n", "            new_channels = logits.argmax(dim=-1)\n", "        else:\n", "            new_channels = torch.distributions.Categorical(logits=logits).sample()\n", "\n", "        sampling_ctx.channel_index = new_channels\n", "        self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=sampling_ctx)\n", "        return data\n", "\n", "    def marginalize(self, marg_rvs: list[int], prune: bool = True, cache=None):\n", "        mutual = set(self.scope.query) & set(marg_rvs)\n", "        if len(mutual) == len(self.scope.query):\n", "            return None\n", "\n", "        marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)\n", "        if marg_input is None:\n", "            return None\n", "\n", "        # Simplification: the marginalized module re-initializes its weights\n", "        # instead of copying the learned ones.\n", "        return NoisySum(\n", "            inputs=marg_input,\n", "            out_channels=self.out_shape.channels,\n", "            noise_std=self.noise_std,\n", "            num_repetitions=self.out_shape.repetitions,\n", "        )" ] }, { "cell_type": "code", "execution_count": 13, "id": "6ccb1cfa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Training: outputs differ = True\n", "Eval: outputs identical = True\n" ] } ], "source": [ "# Test NoisySum\n", "\n", "leaf = NoisyNormal(scope=Scope([0]), out_channels=4)\n", "noisy_sum = NoisySum(inputs=leaf, out_channels=2, noise_std=0.5)\n", "\n", "data = torch.randn(10, 1)\n", "\n", "# Training mode: noise added\n", "noisy_sum.train()\n", "ll1 = noisy_sum.log_likelihood(data)\n", "ll2 = noisy_sum.log_likelihood(data)\n", "print(f\"Training: outputs differ = {not torch.allclose(ll1, ll2)}\")\n", "\n", "# Eval mode: deterministic\n", "noisy_sum.eval()\n", "ll1 = noisy_sum.log_likelihood(data)\n", "ll2 = noisy_sum.log_likelihood(data)\n", "print(f\"Eval: outputs identical = {torch.allclose(ll1, ll2)}\")\n", "\n" ] }, { "cell_type": "markdown", 
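"id": "noisysum-train-md", "metadata": {}, "source": [ "Because `NoisySum` and `NoisyNormal` are ordinary `nn.Module`s, they train end-to-end with plain PyTorch. The next cell is a minimal, self-contained sketch (not part of any SPFlow training API; the synthetic data and hyperparameters are illustrative only) that fits the noisy mixture by gradient descent on the negative log-likelihood:\n" ] }, { "cell_type": "code", "execution_count": null, "id": "noisysum-train-code", "metadata": {}, "outputs": [], "source": [ "# Minimal training sketch: the custom modules above are plain nn.Modules,\n", "# so standard PyTorch optimizers work on them directly.\n", "# (toy_data and all hyperparameters here are illustrative only.)\n", "leaf = NoisyNormal(scope=Scope([0]), out_channels=4)\n", "model = NoisySum(inputs=leaf, out_channels=1, noise_std=0.1)\n", "\n", "optimizer = torch.optim.Adam(model.parameters(), lr=0.05)\n", "toy_data = torch.randn(256, 1) * 2.0 + 1.0  # synthetic 1-D data\n", "\n", "model.train()\n", "for step in range(50):\n", "    optimizer.zero_grad()\n", "    # Noisy NLL during training; (batch, 1, 1, 1) log-likelihoods -> scalar\n", "    nll = -model.log_likelihood(toy_data).mean()\n", "    nll.backward()\n", "    optimizer.step()\n", "\n", "model.eval()\n", "with torch.no_grad():\n", "    eval_nll = -model.log_likelihood(toy_data).mean()\n", "print(f\"eval NLL: {eval_nll.item():.3f}\")" ] }, { "cell_type": "markdown",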
"id": "cbd705ab", "metadata": {}, "source": [ "**Key points:**\n", "- Use `@cached` decorator to enable caching for sampling and EM\n", "- Weights have shape `(features, in_channels, out_channels, repetitions)`\n", "- Use `proj_convex_to_real()` to convert probabilities to unconstrained logits\n", "\n", "**How Sum Sampling Works:**\n", "\n", "Sum modules **select paths** through the DAG—they don't generate samples.\n", "\n", "1. Receive current `sampling_ctx.channel_index` from parent\n", "2. Gather logits for those specific output channels\n", "3. Sample from `Categorical(logits)` (or `argmax` for MPE)\n", "4. **Update** `sampling_ctx.channel_index` with selected child indices\n", "5. Call `self.inputs.sample(...)` to continue traversal\n", "\n", "The key code pattern:\n", "```python\n", "# Select which input channel to use for each sample\n", "if is_mpe:\n", " new_channel_index = torch.argmax(logits, dim=-1)\n", "else:\n", " new_channel_index = Categorical(logits=logits).sample()\n", "\n", "# Update context and delegate to children\n", "sampling_ctx.channel_index = new_channel_index\n", "self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)\n", "```\n", "\n", "## 4. Implementing a Product Module\n", "\n", "Product modules compute joint distributions: $p(x_1, x_2) = p(x_1) \\cdot p(x_2)$.\n", "\n", "To implement a product module from scratch, you need:\n", "1. Concatenate multiple inputs via `Cat` (along feature dimension)\n", "2. `log_likelihood()` that sums log-probs across features\n", "3. `sample()` that delegates to the input with expanded context\n", "4. `feature_to_scope` that joins all input scopes\n", "\n", "### Example: NoisyProduct\n", "\n", "A product module that adds Gaussian noise during training for regularization.\n", "This demonstrates how to add training-time behavior to a product.\n", "\n", "$$\\log p(x) = \\sum_j \\log p_j(x_j) + \\epsilon, \\quad \\epsilon \\sim \\mathcal{N}(0, \\sigma^2)$$" ] }, { "cell_type": "code", "execution_count": 14, "id": "e16a0f66", "metadata": {}, "outputs": [], "source": [ "class NoisyProduct(Module):\n", " \"\"\"Product module with additive Gaussian noise during training.\n", " \n", " Adds noise to log-likelihoods during training for regularization.\n", " Deterministic during evaluation.\n", " \"\"\"\n", "\n", " def __init__(self, inputs, noise_std: float = 0.1):\n", " super().__init__()\n", "\n", " # Handle single module or list of modules\n", " if isinstance(inputs, list):\n", " self.inputs = Cat(inputs, dim=1) if len(inputs) > 1 else inputs[0]\n", " else:\n", " self.inputs = inputs\n", "\n", " self.scope = self.inputs.scope\n", " self.noise_std = noise_std\n", "\n", " # Shape: product reduces features to 1\n", " in_shape = self.inputs.out_shape\n", " self.in_shape = in_shape\n", " self.out_shape = ModuleShape(1, in_shape.channels, in_shape.repetitions)\n", "\n", " @property\n", " def feature_to_scope(self) -> np.ndarray:\n", " # Join all input scopes into a single scope per repetition\n", " out = []\n", " for r in range(self.out_shape.repetitions):\n", " joined = Scope.join_all(self.inputs.feature_to_scope[:, r])\n", " out.append(np.array([[joined]]))\n", " return np.concatenate(out, axis=1)\n", "\n", " @cached\n", " def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:\n", " # Get input log-likelihoods: (batch, features, channels, reps)\n", " ll = self.inputs.log_likelihood(data, cache=cache)\n", "\n", " # Product = sum in log-space, reduce over features (dim=1)\n", " result = torch.sum(ll, dim=1, 
keepdim=True)\n", "\n", " # Add noise during training only\n", " if self.training:\n", " noise = torch.randn_like(result) * self.noise_std\n", " result = result + noise\n", "\n", " return result # Shape: (batch, 1, channels, reps)\n", "\n", " def sample(\n", " self,\n", " num_samples: int | None = None,\n", " data: Tensor | None = None,\n", " is_mpe: bool = False,\n", " cache: Cache | None = None,\n", " sampling_ctx: SamplingContext | None = None,\n", " ) -> Tensor:\n", " data = self._prepare_sample_data(num_samples, data)\n", " sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)\n", "\n", " # Expand context to match input feature count\n", " in_features = self.inputs.out_shape.features\n", " channel_index = sampling_ctx.channel_index.expand(-1, in_features)\n", " mask = sampling_ctx.mask.expand(-1, in_features)\n", " sampling_ctx.update(channel_index=channel_index, mask=mask)\n", "\n", " self.inputs.sample(data=data, is_mpe=is_mpe, cache=cache, sampling_ctx=sampling_ctx)\n", " return data\n", "\n", " def marginalize(self, marg_rvs: list[int], prune: bool = True, cache=None):\n", " mutual = set(self.scope.query) & set(marg_rvs)\n", " if len(mutual) == len(self.scope.query):\n", " return None\n", "\n", " marg_input = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)\n", " if marg_input is None:\n", " return None\n", "\n", " if prune and marg_input.out_shape.features == 1:\n", " return marg_input\n", "\n", " return NoisyProduct(inputs=marg_input, noise_std=self.noise_std)" ] }, { "cell_type": "code", "execution_count": 15, "id": "c034fcc9", "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Product scope: (0, 1)\n", "Training: outputs differ = True\n", "Eval: outputs identical = True\n", "Samples shape: torch.Size([100, 2])\n" ] } ], "source": [ "# Test NoisyProduct\n", "from spflow.modules.leaves import Normal\n", "\n", "leaf1 = Normal(scope=Scope([0]), out_channels=2)\n", "leaf2 = Normal(scope=Scope([1]), out_channels=2)\n", "noisy_prod = NoisyProduct(inputs=[leaf1, leaf2], noise_std=0.5)\n", "\n", "# Verify feature_to_scope matches expected joined scope\n", "print(f\"Product scope: {noisy_prod.scope}\")\n", "assert len(noisy_prod.scope) == 2\n", "\n", "data = torch.randn(5, 2)\n", "\n", "# Training mode: noise is added (outputs differ each call)\n", "noisy_prod.train()\n", "ll1 = noisy_prod.log_likelihood(data)\n", "ll2 = noisy_prod.log_likelihood(data)\n", "print(f\"Training: outputs differ = {not torch.allclose(ll1, ll2)}\")\n", "\n", "# Eval mode: deterministic (outputs identical)\n", "noisy_prod.eval()\n", "ll1 = noisy_prod.log_likelihood(data)\n", "ll2 = noisy_prod.log_likelihood(data)\n", "print(f\"Eval: outputs identical = {torch.allclose(ll1, ll2)}\")\n", "\n", "# Verify sampling works and produces correct shape\n", "samples = noisy_prod.sample(num_samples=100)\n", "print(f\"Samples shape: {samples.shape}\")\n", "assert samples.shape == (100, 2)" ] }, { "cell_type": "markdown", "id": "fe82fbd8", "metadata": {}, "source": [ "**Key points:**\n", "- Products sum log-likelihoods across features (dim=1)\n", "- `Cat(inputs, dim=1)` concatenates inputs along the feature dimension\n", "- Use `self.training` to differentiate train/eval behavior\n", "\n", "**How Product Sampling Works:**\n", "\n", "Products represent factorization: $p(X_1, X_2) = p(X_1) \\cdot p(X_2)$.\n", "They **expand** the sampling context but don't select paths.\n", "\n", "1. 
Inputs have **disjoint scopes** (different random variables)\n", "2. Expand `channel_index` from `(batch, 1)` to `(batch, num_input_features)`\n", "3. Expand `mask` similarly\n", "4. Pass expanded context to children—no selection happens\n", "\n", "The key code pattern:\n", "```python\n", "# Expand context to match number of input features\n", "channel_index = sampling_ctx.channel_index.expand(-1, self.inputs.out_shape.features)\n", "mask = sampling_ctx.mask.expand(-1, self.inputs.out_shape.features)\n", "sampling_ctx.update(channel_index=channel_index, mask=mask)\n", "\n", "# Delegate to children\n", "self.inputs.sample(data=data, sampling_ctx=sampling_ctx, ...)\n", "```\n", "\n", "Products have **no learnable parameters**—they are purely structural.\n", "\n", "## 5. Implementing a Split Module\n", "\n", "Split modules partition an input module into multiple groups. They provide different *views* of the same input.\n", "\n", "### Example: RandomSplit\n", "\n", "A `RandomSplit` assigns features to groups randomly (fixed at initialization)." ] }, { "cell_type": "code", "execution_count": 16, "id": "849042f7", "metadata": {}, "outputs": [], "source": [ "from spflow.modules.ops.split import Split\n", "from spflow.modules.module import Module\n", "\n", "\n", "class RandomSplit(Split):\n", " \"\"\"Split features into groups via random assignment.\"\"\"\n", "\n", " def __init__(self, inputs: Module, num_splits: int = 2, seed: int = 42):\n", " super().__init__(inputs=inputs, dim=1, num_splits=num_splits)\n", "\n", " # Randomly assign each feature to a split\n", " gen = torch.Generator().manual_seed(seed)\n", " num_features = inputs.out_shape.features\n", " assignments = torch.randint(0, num_splits, (num_features,), generator=gen)\n", "\n", " # Create boolean masks for each split\n", " self.split_masks = [assignments == i for i in range(num_splits)]\n", "\n", " # Store assignments for merge_split_indices\n", " self.register_buffer(\"_assignments\", assignments)\n", "\n", " @property\n", " def feature_to_scope(self):\n", " scopes = self.inputs.feature_to_scope\n", " return [\n", " [scopes[j] for j in range(len(scopes)) if self.split_masks[i][j]]\n", " for i in range(self.num_splits)\n", " ]\n", "\n", " @cached\n", " def log_likelihood(self, data, cache=None):\n", " lls = self.inputs.log_likelihood(data, cache=cache)\n", " return [lls[:, mask, ...] 
for mask in self.split_masks]\n", "\n", "\n", " def merge_split_indices(self, *split_indices: Tensor) -> Tensor:\n", " batch_size = split_indices[0].shape[0]\n", " num_features = self.inputs.out_shape.features\n", " # Create output tensor\n", " result = torch.zeros(batch_size, num_features, dtype=split_indices[0].dtype, device=split_indices[0].device)\n", " # Track position within each split\n", " split_positions = [0] * self.num_splits\n", " # Scatter indices back to original positions\n", " for feature_idx in range(num_features):\n", " split_idx = self._assignments[feature_idx].item()\n", " pos = split_positions[split_idx]\n", " result[:, feature_idx] = split_indices[split_idx][:, pos]\n", " split_positions[split_idx] += 1\n", " return result" ] }, { "cell_type": "code", "execution_count": 17, "id": "f5b66811", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Split 0 shape: torch.Size([5, 5, 2, 1])\n", "Split 1 shape: torch.Size([5, 1, 2, 1])\n" ] } ], "source": [ "# Test RandomSplit\n", "leaf = Normal(scope=Scope(list(range(6))), out_channels=2)\n", "split = RandomSplit(inputs=leaf, num_splits=2, seed=123)\n", "\n", "data = torch.randn(5, 6)\n", "lls = split.log_likelihood(data)\n", "\n", "for i, ll in enumerate(lls):\n", " print(f\"Split {i} shape: {ll.shape}\")\n", "\n" ] }, { "cell_type": "markdown", "id": "fa390c6d", "metadata": {}, "source": [ "**Key points:**\n", "- Split modules return a *list* of tensors (one per split)\n", "- `self.inputs` is a single module in Split (unlike Sum/Product which can wrap multiple)\n", "- `feature_to_scope` must return a list of scope lists, one per split\n", "\n", "## 6. Testing Your Module\n", "\n", "SPFlow provides test utilities in `tests/utils/`. Here's a minimal test pattern:" ] }, { "cell_type": "code", "execution_count": 18, "id": "81fc02a3", "metadata": { "lines_to_next_cell": 2 }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "All tests passed!\n" ] } ], "source": [ "def test_noisy_normal_leaf():\n", " \"\"\"Test basic functionality of NoisyNormal leaf.\"\"\"\n", " leaf = NoisyNormal(scope=Scope([0]), out_channels=4, num_repetitions=1, noise_std=0.1)\n", "\n", " # Check shapes\n", " assert leaf.out_shape.features == 1\n", " assert leaf.out_shape.channels == 4\n", " assert leaf.out_shape.repetitions == 1\n", "\n", " # Check log-likelihood\n", " data = torch.randn(10, 1)\n", " leaf.eval() # Deterministic mode for testing\n", " ll = leaf.log_likelihood(data)\n", " assert ll.shape == (10, 1, 4, 1)\n", " assert torch.isfinite(ll).all()\n", "\n", " # Check sampling\n", " samples = leaf.sample(num_samples=100)\n", " assert samples.shape == (100, 1)\n", "\n", "test_noisy_normal_leaf()\n", "print(\"All tests passed!\")" ] }, { "cell_type": "markdown", "id": "a174a3f2", "metadata": {}, "source": [ "> **Contributing to SPFlow:** If you want to contribute your module to the SPFlow repository, unit tests are **required**. See `tests/` for examples and use `pytest` with parametrization for thorough coverage.\n", "\n", "## 7. 
Reference Implementations\n", "\n", "For complete examples, see:\n", "\n", "| Module Type | Reference File |\n", "|------------|----------------|\n", "| Leaf | `spflow/modules/leaves/normal.py` (~115 lines) |\n", "| Sum | `spflow/modules/sums/sum.py` |\n", "| Product | `spflow/modules/products/product.py` (~210 lines) |\n", "| Split | `spflow/modules/ops/split.py` |" ] } ], "metadata": { "jupytext": { "cell_metadata_filter": "-all", "encoding": "# coding: utf-8", "executable": "/usr/bin/env python", "main_language": "python", "notebook_metadata_filter": "-all" }, "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.12" } }, "nbformat": 4, "nbformat_minor": 5 }