Source code for spflow.modules.rat.rat_spn

"""Random and Tensorized Sum-Product Networks (RAT-SPNs) implementation.

RAT-SPNs provide a principled approach to building deep probabilistic models
through randomized circuit construction, combining interpretability with
expressiveness through tensorized operations.

Reference:
    Peharz, R., et al. (2019). "Random Sum-Product Networks: A Simple and
    Effective Approach to Probabilistic Deep Learning." UAI 2019.
"""

from __future__ import annotations

import numpy as np
import torch

from spflow.exceptions import InvalidParameterError, UnsupportedOperationError
from spflow.interfaces.classifier import Classifier
from spflow.meta.data.scope import Scope
from spflow.modules.module import Module
from spflow.modules.leaves.leaf import LeafModule
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.modules.ops.split_interleaved import SplitInterleaved
from spflow.modules.products.elementwise_product import ElementwiseProduct
from spflow.modules.products.outer_product import OuterProduct
from spflow.modules.rat.factorize import Factorize
from spflow.modules.sums.repetition_mixing_layer import RepetitionMixingLayer
from spflow.modules.sums.sum import Sum
from spflow.utils.cache import Cache, cached
from spflow.utils.inference import log_posterior
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context


class RatSPN(Module, Classifier):
    """Random and Tensorized Sum-Product Network (RAT-SPN).

    Scalable deep probabilistic model with randomized circuit construction.
    Consists of alternating sum (region) and product (partition) layers that
    recursively partition the input space. Random construction prevents
    overfitting while maintaining tractable exact inference.

    Attributes:
        leaf_modules (list[LeafModule]): Leaf distribution modules.
        n_root_nodes (int): Number of root sum nodes.
        n_region_nodes (int): Number of sum nodes per region.
        depth (int): Number of partition/region layers.
        num_repetitions (int): Number of parallel circuit instances.
        scope (Scope): Combined scope of all leaf modules.

    Reference:
        Peharz, R., et al. (2019). "Random Sum-Product Networks: A Simple and
        Effective Approach to Probabilistic Deep Learning." UAI 2019.
    """
    def __init__(
        self,
        leaf_modules: list[LeafModule],
        n_root_nodes: int,
        n_region_nodes: int,
        num_repetitions: int,
        depth: int,
        outer_product: bool | None = False,
        split_mode: SplitMode | None = None,
        num_splits: int | None = 2,
    ) -> None:
        """Initialize RAT-SPN with the specified architecture parameters.

        Creates a Random and Tensorized SPN by recursively constructing layers
        of sum and product nodes. The circuit structure is fixed after
        initialization.

        Args:
            leaf_modules (list[LeafModule]): Leaf distributions forming the base layer.
            n_root_nodes (int): Number of root sum nodes in the final mixture.
            n_region_nodes (int): Number of sum nodes in each region layer.
            num_repetitions (int): Number of parallel circuit instances.
            depth (int): Number of partition/region layers.
            outer_product (bool | None, optional): Use outer product instead of
                elementwise product for partitions. Defaults to False.
            split_mode (SplitMode | None, optional): Split configuration. Use
                SplitMode.consecutive() or SplitMode.interleaved(). Defaults to
                SplitMode.consecutive(num_splits) if not specified.
            num_splits (int | None, optional): Number of splits in each
                partition. Must be at least 2. Defaults to 2.

        Raises:
            InvalidParameterError: If architectural parameters are invalid.
        """
        super().__init__()
        self.n_root_nodes = n_root_nodes
        self.n_region_nodes = n_region_nodes
        self.n_leaf_nodes = leaf_modules[0].out_shape.channels
        self.leaf_modules = leaf_modules
        self.depth = depth
        self.num_repetitions = num_repetitions
        self.outer_product = outer_product
        self.num_splits = num_splits
        self.split_mode = split_mode if split_mode is not None else SplitMode.consecutive(num_splits)
        self.scope = Scope.join_all([leaf.scope for leaf in leaf_modules])

        if n_root_nodes < 1:
            raise InvalidParameterError(
                f"Specified value for 'n_root_nodes' must be at least 1, but is {n_root_nodes}."
            )
        if n_region_nodes < 1:
            raise InvalidParameterError(
                f"Specified value for 'n_region_nodes' must be at least 1, but is {n_region_nodes}."
            )
        if self.n_leaf_nodes < 1:
            raise InvalidParameterError(
                f"Specified value for 'n_leaf_nodes' must be at least 1, but is {self.n_leaf_nodes}."
            )
        if self.num_splits < 2:
            raise InvalidParameterError(
                f"Specified value for 'num_splits' must be at least 2, but is {self.num_splits}."
            )

        self.create_spn()

        # Shape computation: delegate to the root node.
        self.in_shape = self.root_node.in_shape
        self.out_shape = self.root_node.out_shape
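    # Usage sketch (hedged): constructing a small RAT-SPN over 8 features.
    # The leaf class `Normal`, its import path, and its constructor arguments
    # are assumptions for illustration; any LeafModule with a matching scope
    # and channel count should work.
    #
    #     from spflow.modules.leaves.normal import Normal  # hypothetical import path
    #
    #     leaves = [Normal(scope=Scope(list(range(8))), out_channels=4)]
    #     spn = RatSPN(
    #         leaf_modules=leaves,
    #         n_root_nodes=1,
    #         n_region_nodes=8,
    #         num_repetitions=4,
    #         depth=2,
    #     )
    #     ll = spn.log_likelihood(torch.randn(32, 8))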
    def create_spn(self) -> None:
        """Create the RAT-SPN architecture.

        Builds the RAT-SPN circuit structure from bottom to top based on the
        provided architectural parameters. The architecture is constructed
        recursively from leaves to root using alternating layers of sum and
        product nodes; the final structure depends on the depth and branching
        parameters. See the structural sketch below this method.
        """
        product_layer = OuterProduct if self.outer_product else ElementwiseProduct

        # Factorize the leaf modules.
        fac_layer = Factorize(
            inputs=self.leaf_modules, depth=self.depth, num_repetitions=self.num_repetitions
        )

        depth = self.depth
        root = None
        for i in range(depth):
            if i == 0:
                # Create the lowest layer with the factorized leaf modules as input.
                out_prod = product_layer(inputs=self.split_mode.create(fac_layer))
                if depth == 1:
                    sum_layer = Sum(
                        inputs=out_prod,
                        out_channels=self.n_root_nodes,
                        num_repetitions=self.num_repetitions,
                    )
                else:
                    sum_layer = Sum(
                        inputs=out_prod,
                        out_channels=self.n_region_nodes,
                        num_repetitions=self.num_repetitions,
                    )
                root = sum_layer
            elif i == depth - 1:
                # Special case: the last intermediate layer feeds the root nodes.
                out_prod = product_layer(self.split_mode.create(root))
                sum_layer = Sum(
                    inputs=out_prod, out_channels=self.n_root_nodes, num_repetitions=self.num_repetitions
                )
                root = sum_layer
            else:
                # Create the intermediate layers.
                out_prod = product_layer(self.split_mode.create(root))
                sum_layer = Sum(
                    inputs=out_prod, out_channels=self.n_region_nodes, num_repetitions=self.num_repetitions
                )
                root = sum_layer

        # Mixing layer: sums over repetitions.
        root = RepetitionMixingLayer(
            inputs=root, out_channels=self.n_root_nodes, num_repetitions=self.num_repetitions
        )

        # Root node: sum over all out_channels.
        if self.n_root_nodes > 1:
            self.root_node = Sum(inputs=root, out_channels=1, num_repetitions=1)
        else:
            self.root_node = root
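    # Structural sketch (hedged): for depth=2, create_spn yields the following
    # stack, bottom to top. Layer names follow the classes used above; the
    # layout is schematic, not a literal module graph:
    #
    #     Factorize(leaves)                      # randomized scope assignment per repetition
    #       -> Split -> ElementwiseProduct       # partition layer 0
    #       -> Sum(out_channels=n_region_nodes)  # region layer 0
    #       -> Split -> ElementwiseProduct       # partition layer 1
    #       -> Sum(out_channels=n_root_nodes)    # region layer 1 (last)
    #       -> RepetitionMixingLayer             # mixes the num_repetitions parallel circuits
    #       -> Sum(out_channels=1)               # only if n_root_nodes > 1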
    @property
    def n_out(self) -> int:
        return 1

    @property
    def feature_to_scope(self) -> np.ndarray:
        return self.root_node.feature_to_scope

    @property
    def scopes_out(self) -> list[Scope]:
        return self.root_node.scopes_out
    @cached
    def log_likelihood(
        self,
        data: torch.Tensor,
        cache: Cache | None = None,
    ) -> torch.Tensor:
        """Compute the log-likelihood of the data under the RAT-SPN.

        Args:
            data: Input data tensor.
            cache: Optional cache for intermediate results.

        Returns:
            Log-likelihood values.
        """
        return self.root_node.log_likelihood(data, cache=cache)
    def log_posterior(
        self,
        data: torch.Tensor,
        cache: Cache | None = None,
    ) -> torch.Tensor:
        """Compute log-posterior probabilities for multi-class models.

        Args:
            data: Input data tensor.
            cache: Optional cache for intermediate results.

        Returns:
            Log-posterior probabilities.

        Raises:
            UnsupportedOperationError: If the model has only one root node
                (single class).
        """
        if self.n_root_nodes <= 1:
            raise UnsupportedOperationError(
                "Posterior can only be computed for models with multiple classes."
            )

        # The root sum weights act as the log-prior over classes.
        ll_y = self.root_node.log_weights  # shape: (1, n_root_nodes, 1, 1)
        ll_y = ll_y.squeeze().view(1, -1)  # shape: (1, n_root_nodes)

        # Class-conditional log-likelihoods from the children of the root sum.
        ll = self.root_node.inputs.log_likelihood(
            data,
            cache=cache,
        )  # shape: (batch_size, 1, n_root_nodes)
        # Remove the repetition and feature dimensions -> shape: (batch_size, n_root_nodes).
        ll = ll.squeeze(-1).squeeze(1)

        return log_posterior(log_likelihood=ll, log_prior=ll_y)
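    # Math sketch (hedged): with the root sum weights as log-prior log p(y) and
    # the per-class log-likelihoods log p(x | y), the helper `log_posterior` is
    # expected to apply Bayes' rule in log-space, i.e. roughly:
    #
    #     joint = ll + ll_y                                                 # log p(x, y)
    #     log_post = joint - torch.logsumexp(joint, dim=-1, keepdim=True)   # log p(y | x)
    #
    # The exact implementation lives in spflow.utils.inference.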
    def predict_proba(self, data: torch.Tensor) -> torch.Tensor:
        """Compute class probabilities for the input data.

        Args:
            data: Input data tensor.

        Returns:
            Predicted class probabilities.
        """
        log_post = self.log_posterior(data)
        return torch.exp(log_post)
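    # Usage sketch (hedged): hard class labels can be obtained from the
    # probabilities via argmax; `spn` and `x` are placeholder names for a
    # multi-class model (n_root_nodes > 1) and an input batch:
    #
    #     proba = spn.predict_proba(x)     # shape: (batch_size, n_root_nodes)
    #     labels = proba.argmax(dim=-1)    # shape: (batch_size,)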
    def sample(
        self,
        num_samples: int | None = None,
        data: torch.Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        sampling_ctx: SamplingContext | None = None,
    ) -> torch.Tensor:
        """Generate samples from the RAT-SPN.

        Args:
            num_samples: Number of samples to generate.
            data: Data tensor with NaN values to fill with samples.
            is_mpe: Whether to perform most-probable-explanation (MPE)
                inference instead of random sampling.
            cache: Optional cache.
            sampling_ctx: Optional sampling context.

        Returns:
            Sampled values.
        """
        # Handle the num_samples case (create an empty data tensor).
        if data is None:
            if num_samples is None:
                num_samples = 1
            data = torch.full((num_samples, len(self.scope.query)), torch.nan, device=self.device)

        # If no sampling context is provided, initialize one by sampling from the root node.
        if sampling_ctx is None and self.n_root_nodes > 1:
            sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)
            logits = self.root_node.logits
            if logits.shape != (1, self.n_root_nodes, 1):
                raise InvalidParameterError(
                    f"Expected logits shape (1, {self.n_root_nodes}, 1), but got {logits.shape}"
                )
            logits = logits.squeeze(-1)
            logits = logits.unsqueeze(0).expand(data.shape[0], -1, -1)  # shape: (batch_size, 1, n_root_nodes)
            if is_mpe:
                sampling_ctx.channel_index = torch.argmax(logits, dim=-1)
            else:
                sampling_ctx.channel_index = torch.distributions.Categorical(logits=logits).sample()
        else:
            sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)

        # If the model has multiple root nodes, the root channel choice was already
        # made above, so sample from the mixing layer (the root node's input) directly.
        if self.n_root_nodes > 1:
            sample_root = self.root_node.inputs
        else:
            sample_root = self.root_node

        return sample_root.sample(
            data=data,
            is_mpe=is_mpe,
            cache=cache,
            sampling_ctx=sampling_ctx,
        )
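    # Usage sketch (hedged): unconditional and conditional sampling. NaN
    # entries mark the variables to be filled in; observed values condition
    # the samples. `spn` is a placeholder for a model over 8 features:
    #
    #     x_new = spn.sample(num_samples=16)                # 16 unconditional samples
    #
    #     x = torch.randn(16, 8)
    #     x[:, 4:] = torch.nan                              # observe the first 4 features
    #     x_filled = spn.sample(data=x.clone())             # fill the NaN positions
    #     x_mpe = spn.sample(data=x.clone(), is_mpe=True)   # most probable completion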
    def expectation_maximization(
        self,
        data: torch.Tensor,
        cache: Cache | None = None,
    ) -> None:
        """Perform an expectation-maximization step.

        Args:
            data: Input data tensor.
            cache: Optional cache.
        """
        self.root_node.expectation_maximization(data, cache=cache)
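    # Training sketch (hedged): a minimal EM loop. Each call performs one
    # update on the full batch; `spn`, `train_data`, and `num_em_steps` are
    # placeholder names:
    #
    #     for _ in range(num_em_steps):
    #         spn.expectation_maximization(train_data)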
    def maximum_likelihood_estimation(
        self,
        data: torch.Tensor,
        weights: torch.Tensor | None = None,
        cache: Cache | None = None,
    ) -> None:
        """Update parameters via maximum likelihood estimation.

        Args:
            data: Input data tensor.
            weights: Optional sample weights.
            cache: Optional cache.
        """
        self.root_node.maximum_likelihood_estimation(
            data,
            weights=weights,
            cache=cache,
        )
    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> Module | None:
        """Marginalize out the specified random variables.

        Args:
            marg_rvs: List of random variables to marginalize.
            prune: Whether to prune the module.
            cache: Optional cache.

        Returns:
            Marginalized module or None.
        """
        return self.root_node.marginalize(marg_rvs, prune=prune, cache=cache)
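# Usage sketch (hedged): marginalization returns a new model over the reduced
# scope, which is useful for computing marginal likelihoods. The exact input
# layout expected by the marginalized model depends on the implementation:
#
#     marg_spn = spn.marginalize(marg_rvs=[0, 1])  # integrate out variables 0 and 1
#     # marg_spn (if not None) is a model over the remaining scope.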