Source code for spflow.modules.ops.split

"""Split operations for tensor partitioning in probabilistic circuits.

Provides base classes and implementations for splitting tensors along
dimensions. Essential for RAT-SPNs and other architectures requiring
systematic tensor partitioning.
"""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any, Dict, Optional

import numpy as np
from torch import Tensor, nn

from spflow.exceptions import InvalidParameterError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.utils.sampling_context import (
    SamplingContext,
    init_default_sampling_context,
)


class SplitMode:
    """Configuration for split operations.

    Factory class for creating split configurations. Use the class methods to
    create split configurations that can be passed to modules.

    Example:
        >>> layer = EinsumLayer(
        ...     inputs=leaf,
        ...     num_repetitions=3,
        ...     out_channels=10,
        ...     split_mode=SplitMode.interleaved(num_splits=3),
        ... )
        >>> layer = LinsumLayer(inputs=leaf, out_channels=10, split_mode=SplitMode.consecutive(num_splits=2))
        >>> layer = LinsumLayer(inputs=leaf, out_channels=10, split_mode=SplitMode.by_index(indices=[[0, 1], [2, 3]]))
    """
    def __init__(self, split_type: str, num_splits: int = 2, indices: list[list[int]] | None = None):
        """Initialize split mode configuration.

        Args:
            split_type: Type of split ('consecutive', 'interleaved', or 'by_index').
            num_splits: Number of parts to split into.
            indices: For 'by_index' type, the feature indices for each split.
        """
        if split_type not in ("consecutive", "interleaved", "by_index"):
            raise InvalidParameterError(
                f"split_type must be 'consecutive', 'interleaved', or 'by_index', got '{split_type}'"
            )
        if split_type != "by_index" and num_splits < 2:
            raise InvalidParameterError(f"num_splits must be at least 2, got {num_splits}")
        if split_type == "by_index" and indices is None:
            raise InvalidParameterError("indices must be provided for 'by_index' split type")
        self._split_type = split_type
        self._num_splits = num_splits
        self._indices = indices
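    # Hedged doctest-style sketch of the validation above:
    #
    #     >>> SplitMode("diagonal")
    #     Traceback (most recent call last):
    #         ...
    #     InvalidParameterError: split_type must be 'consecutive', 'interleaved', or 'by_index', got 'diagonal'
    #     >>> SplitMode("by_index")
    #     Traceback (most recent call last):
    #         ...
    #     InvalidParameterError: indices must be provided for 'by_index' split type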
    @property
    def num_splits(self) -> int:
        """Number of splits."""
        return self._num_splits

    @property
    def split_type(self) -> str:
        """Type of split ('consecutive', 'interleaved', or 'by_index')."""
        return self._split_type

    @property
    def indices(self) -> list[list[int]] | None:
        """Feature indices for 'by_index' split type."""
        return self._indices
    @classmethod
    def consecutive(cls, num_splits: int = 2) -> "SplitMode":
        """Create a consecutive split configuration.

        Splits features into consecutive chunks: [0, 1, 2, 3] -> [0, 1], [2, 3].

        Args:
            num_splits: Number of parts to split into.

        Returns:
            SplitMode configuration for consecutive splitting.
        """
        return cls("consecutive", num_splits)
    @classmethod
    def interleaved(cls, num_splits: int = 2) -> "SplitMode":
        """Create an interleaved split configuration.

        Splits features using modulo: [0, 1, 2, 3] -> [0, 2], [1, 3].

        Args:
            num_splits: Number of parts to split into.

        Returns:
            SplitMode configuration for interleaved splitting.
        """
        return cls("interleaved", num_splits)
    @classmethod
    def by_index(cls, indices: list[list[int]]) -> "SplitMode":
        """Create a split configuration with explicit feature indices.

        Splits features according to the specified indices. Each inner list
        contains the feature indices for that split.

        Example:
            >>> SplitMode.by_index([[0, 1, 4], [2, 3, 5, 6, 7]])
            SplitMode.by_index(indices=[[0, 1, 4], [2, 3, 5, 6, 7]])

        Args:
            indices: List of lists specifying the feature indices for each split.
                All features must be covered exactly once.

        Returns:
            SplitMode configuration for index-based splitting.
        """
        return cls("by_index", num_splits=len(indices), indices=indices)
    def create(self, inputs: Module) -> "Split":
        """Create a Split module with this configuration.

        Args:
            inputs: Input module to split.

        Returns:
            Split module configured according to this SplitMode.
        """
        # Import here to avoid circular imports
        from spflow.modules.ops.split_consecutive import SplitConsecutive
        from spflow.modules.ops.split_interleaved import SplitInterleaved
        from spflow.modules.ops.split_by_index import SplitByIndex

        if self._split_type == "consecutive":
            return SplitConsecutive(inputs, num_splits=self._num_splits)
        elif self._split_type == "interleaved":
            return SplitInterleaved(inputs, num_splits=self._num_splits)
        else:  # by_index
            return SplitByIndex(inputs, indices=self._indices)
    def __repr__(self) -> str:
        if self._split_type == "by_index":
            return f"SplitMode.by_index(indices={self._indices})"
        return f"SplitMode.{self._split_type}(num_splits={self._num_splits})"
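# Hedged usage sketch (doctest-style; `leaf` stands in for any spflow input
# module and is not defined in this file):
#
#     >>> mode = SplitMode.interleaved(num_splits=3)
#     >>> mode
#     SplitMode.interleaved(num_splits=3)
#     >>> split = mode.create(leaf)  # dispatches to SplitInterleaved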
class Split(Module, ABC):
    """Abstract base class for tensor splitting operations.

    Splits input tensors along specified dimensions. Concrete implementations
    must provide the feature_to_scope property and the merge_split_indices
    method.

    Attributes:
        inputs (Module): Single input module to split.
        dim (int): Dimension along which to split (0=batch, 1=feature, 2=channel).
        num_splits (int): Number of splits to create.
        scope (Scope): Variable scope inherited from input.
    """
    def __init__(self, inputs: Module, dim: int = 1, num_splits: int | None = 2):
        """Initialize split operation.

        Args:
            inputs: Input module to split.
            dim: Dimension along which to split (0=batch, 1=feature, 2=channel).
            num_splits: Number of parts to split into.
        """
        super().__init__()
        if not isinstance(inputs, Module):
            raise InvalidParameterError(f"'{self.__class__.__name__}' requires a single Module as input.")
        self.inputs = inputs
        self.dim = dim
        self.num_splits = num_splits
        self.scope = self.inputs.scope

        # Shape computation
        in_shape = self.inputs.out_shape
        self.in_shape = in_shape
        self.out_shape = ModuleShape(in_shape.features, in_shape.channels, in_shape.repetitions)
    def get_out_shapes(self, event_shape):
        """Get output shapes for each split based on the input event shape.

        Args:
            event_shape: Shape of the input event tensor.

        Returns:
            List of tuples representing the output shapes for each split.
        """
        split_size = event_shape[self.dim]
        quotient = split_size // self.num_splits
        remainder = split_size % self.num_splits
        if self.dim == 0:
            if remainder == 0:
                return [(quotient, event_shape[1])] * self.num_splits
            else:
                return [(quotient, event_shape[1])] * (self.num_splits - 1) + [(remainder, event_shape[1])]
        else:
            if remainder == 0:
                return [(event_shape[0], quotient)] * self.num_splits
            else:
                return [(event_shape[0], quotient)] * (self.num_splits - 1) + [(event_shape[0], remainder)]
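    # Hedged doctest-style trace of the arithmetic above (the `split` instance
    # and its shapes are illustrative assumptions):
    #
    #     >>> split.dim, split.num_splits
    #     (1, 3)
    #     >>> split.get_out_shapes((32, 10))   # quotient 3, remainder 1
    #     [(32, 3), (32, 3), (32, 1)]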
    @property
    @abstractmethod
    def feature_to_scope(self) -> np.ndarray:
        """Mapping from each split's output features to their variable scopes."""
        pass
    @abstractmethod
    def merge_split_indices(self, *split_indices: Tensor) -> Tensor:
        """Merge per-split channel indices back into the original feature layout.

        Takes the channel indices for each split and combines them into indices
        matching the original (unsplit) feature layout. Used by parent modules
        (such as EinsumLayer) during sampling.

        Args:
            *split_indices: Channel index tensors for each split, each of shape
                (batch, features_per_split).

        Returns:
            Merged indices matching the input module's feature layout, of shape
            (batch, total_features).
        """
        pass
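    # Hedged sketch of a concrete override (an assumption for illustration,
    # not the actual SplitConsecutive code): for a consecutive split the
    # per-split index tensors can simply be concatenated along the feature
    # dimension,
    #
    #     def merge_split_indices(self, *split_indices: Tensor) -> Tensor:
    #         return torch.cat(split_indices, dim=1)
    #
    # whereas an interleaved split must also re-interleave the columns to
    # restore the original feature order.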
    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Optional[Dict[str, Any]] = None,
        sampling_ctx: SamplingContext | None = None,
    ) -> Tensor:
        """Generate samples by delegating to the input module.

        Args:
            num_samples: Number of samples to generate.
            data: Existing data tensor to modify.
            is_mpe: Whether to perform most probable explanation (MPE) inference.
            cache: Cache dictionary for intermediate results.
            sampling_ctx: Sampling context for controlling sample generation.

        Returns:
            Tensor containing the generated samples.
        """
        # Prepare data tensor
        data = self._prepare_sample_data(num_samples, data)

        # Initialize context
        sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0])

        # Expand mask and channel indices to match the input module's shape
        mask = sampling_ctx.mask.expand(data.shape[0], self.inputs.out_shape.features)
        channel_index = sampling_ctx.channel_index.expand(data.shape[0], self.inputs.out_shape.features)
        sampling_ctx.update(channel_index=channel_index, mask=mask)

        self.inputs.sample(
            data=data,
            is_mpe=is_mpe,
            cache=cache,
            sampling_ctx=sampling_ctx,
        )

        return data
    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Optional[Dict[str, Any]] = None,
    ) -> None | Module:
        """Marginalize out the specified random variables.

        Args:
            marg_rvs: List of random variable indices to marginalize.
            prune: Whether to prune the resulting module.
            cache: Cache dictionary for intermediate results.

        Returns:
            Marginalized module, or None if fully marginalized.
        """
        # Compute module scope (same for all outputs)
        module_scope = self.scope
        mutual_rvs = set(module_scope.query).intersection(set(marg_rvs))

        # Module scope overlaps the marginalized variables
        if mutual_rvs:
            # Marginalize the child module
            marg_child_module = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)

            # If the child survived marginalization, prune or re-wrap it
            if marg_child_module:
                if prune and marg_child_module.out_shape.features == 1:
                    # A single remaining feature needs no split: return the child directly
                    return marg_child_module
                else:
                    return self.__class__(inputs=marg_child_module, dim=self.dim, num_splits=self.num_splits)
            # The child was fully marginalized, so this module vanishes as well
            else:
                return None
        # Scope is disjoint from the marginalized variables: return self unchanged
        else:
            return self
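# Hedged end-to-end sketch. Only SplitMode is exercised here, since a concrete
# Split requires a real spflow input module; the commented-out `leaf` line is
# a hypothetical usage, not part of this module.
if __name__ == "__main__":
    mode = SplitMode.consecutive(num_splits=2)
    assert mode.split_type == "consecutive"
    assert repr(mode) == "SplitMode.consecutive(num_splits=2)"

    by_idx = SplitMode.by_index([[0, 2], [1, 3]])
    assert by_idx.num_splits == 2
    # split_module = by_idx.create(leaf)  # would return a SplitByIndex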