Wrapper Modules¶
Wrapper classes that adapt SPFlow modules for specific data formats and use cases.
The wrapper pattern enables flexible extension of SPFlow capabilities while maintaining compatibility with the core module interfaces.
Wrapper (Base Class)¶
Abstract base class for SPFlow module wrappers.
- class spflow.modules.wrapper.base.Wrapper(module)[source]¶
Bases: Module, ABC
Abstract base class for SPFlow module wrappers.
Provides a foundation for creating wrapper modules that adapt existing SPFlow modules for specific use cases, data formats, or integration scenarios. Wrapper modules delegate most operations to the wrapped module while providing specialized functionality for specific contexts. All abstract methods from Module are delegated to the wrapped module by default; concrete implementations should override specific methods as needed. Wrapper modules inherit their scope and structure from the wrapped module.
The wrapper pattern enables:
- Custom data format handling (images, sequences, etc.)
- Preprocessing and postprocessing integration
- External framework compatibility layers
- Specialized input/output transformations
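A minimal pure-Python sketch of the delegation idea described above. Note that `Module`, `Wrapper`, and `ShiftWrapper` here are simplified stand-ins for illustration, not the real spflow classes; the real base class delegates the full Module interface.

```python
from abc import ABC


class Module:
    """Stand-in for spflow's Module base class (hypothetical, for illustration)."""

    def __init__(self, scope):
        self.scope = scope

    def log_likelihood(self, data):
        # Dummy per-sample log-likelihoods.
        return [0.0 for _ in data]


class Wrapper(Module, ABC):
    """Sketch of the wrapper pattern: delegate operations to the wrapped module."""

    def __init__(self, module):
        self.module = module
        # Scope is inherited from the wrapped module.
        self.scope = module.scope

    def log_likelihood(self, data):
        # Default behavior: pure delegation.
        return self.module.log_likelihood(data)


class ShiftWrapper(Wrapper):
    """Hypothetical wrapper that preprocesses inputs before delegating."""

    def log_likelihood(self, data):
        shifted = [x - 1.0 for x in data]  # custom preprocessing step
        return self.module.log_likelihood(shifted)
```

A concrete wrapper overrides only the methods where it adds behavior (here, a preprocessing step); everything else falls through to the wrapped module.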
- module¶
The wrapped SPFlow module.
- Type:
Module
- scope¶
Variable scope inherited from wrapped module.
- Type:
Scope
- __init__(module)[source]¶
Initialize wrapper with specified SPFlow module.
Creates a wrapper that delegates most operations to the wrapped module while allowing specialized overrides for specific functionality. The wrapped module’s scope and structure are preserved, and all abstract Module interface methods are delegated by default; override specific methods to add custom wrapper functionality.
- Parameters:
module (Module) – The SPFlow module to wrap. Can be any valid SPFlow module including complex circuit structures.
- property device¶
Returns the device of the wrapped module.
Automatically determines the device where the wrapped module’s parameters are located, handling multi-device scenarios gracefully.
- Returns:
Device where the wrapped module parameters are located.
- Return type:
device
ImageWrapper¶
Adapts SPFlow modules for 4D image data (batch, channels, height, width). Provides automatic conversion between flattened tensors and image format.
- class spflow.modules.wrapper.image_wrapper.ImageWrapper(module, num_channel, height, width)[source]¶
Bases: Wrapper
Wrapper for adapting SPFlow modules to image data format.
Provides automatic conversion between the 2D flattened tensors used by SPFlow modules and the 4D image tensors (batch, channels, height, width) common in computer vision applications. The wrapper automatically validates image dimensions against the module scope, handles conversion between 2D and 4D tensor formats, and supports all standard SPFlow operations with image data. This maintains the spatial structure of the data while enabling the use of standard SPFlow modules.
- module¶
Wrapped SPFlow module.
- Type:
Module
- __init__(module, num_channel, height, width)[source]¶
Initialize image wrapper.
Creates a wrapper that adapts SPFlow modules for image data with automatic format conversion and validation.
- Parameters:
- module (Module) – The SPFlow module to wrap.
- num_channel (int) – Number of image channels.
- height (int) – Image height in pixels.
- width (int) – Image width in pixels.
- Raises:
StructureError – If module scope size doesn’t match image dimensions.
- flatten(tensor)[source]¶
Convert 4D image tensor to 2D flattened tensor.
- Parameters:
tensor (Tensor) – 4D tensor of shape (batch, channels, height, width).
- Returns:
2D flattened tensor of shape (batch, channels*height*width).
- Raises:
ShapeError – If tensor is not 4D or channel dimension mismatch.
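The flattening layout can be sketched in pure Python. This assumes the row-major ordering of a standard tensor reshape, so pixel (c, h, w) maps to flat index c*H*W + h*W + w; the function below is an illustrative stand-in, not the actual implementation.

```python
def flatten(img, num_channel, height, width):
    """Flatten one image given as nested lists [c][h][w] into a single row.

    Mirrors the (batch, channels*height*width) layout described above,
    assuming row-major ordering: flat index of (c, h, w) is c*H*W + h*W + w.
    """
    if len(img) != num_channel:
        # Corresponds to the channel-dimension-mismatch error above.
        raise ValueError("channel dimension mismatch")
    return [img[c][h][w]
            for c in range(num_channel)
            for h in range(height)
            for w in range(width)]
```

For example, a single-channel 2x2 image `[[[1, 2], [3, 4]]]` flattens to `[1, 2, 3, 4]`.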
- log_likelihood(data, cache=None)[source]¶
Computes log-likelihoods for the wrapped module given the data.
Missing values (i.e., NaN) are marginalized over.
- Parameters:
- data (Tensor) – Four-dimensional PyTorch tensor containing the input data. Shape: (batch_size, num_channel, height, width).
- cache (Cache | None) – Optional cache dictionary for memoization.
- Return type:
Tensor
- Returns:
Two-dimensional PyTorch tensor containing the log-likelihoods of the input data. Each row corresponds to an input sample.
- Raises:
ValueError – Data outside of support.
- marginalize(marg_ctx, prune=True, cache=None)[source]¶
Marginalize out spatial dimensions from the wrapped module.
- Parameters:
- marg_ctx (MarginalizationContext) – MarginalizationContext specifying which dimensions to marginalize.
- prune (bool) – Whether to prune the structure after marginalization.
- cache (Cache | None) – Optional cache dictionary for memoization.
- Return type:
ImageWrapper | None
- Returns:
New ImageWrapper with marginalized module and adjusted dimensions, or None if module is fully marginalized.
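One plausible reading of spatial marginalization is that each channel, height, or width index selects every flattened scope variable it touches. The sketch below assumes the row-major flat layout c*H*W + h*W + w; the function name and the exact mapping are illustrative assumptions, not the library's implementation.

```python
def spatial_to_flat_indices(c_idx, h_idx, w_idx, num_channel, height, width):
    """Map spatial marginalization indices to flat scope indices.

    Assumes the row-major layout c*H*W + h*W + w. Marginalizing a whole
    channel/row/column marks every flat index that touches it.
    """
    marked = set()
    for c in c_idx:  # full channels
        marked |= {c * height * width + h * width + w
                   for h in range(height) for w in range(width)}
    for h in h_idx:  # full rows, across all channels
        marked |= {c * height * width + h * width + w
                   for c in range(num_channel) for w in range(width)}
    for w in w_idx:  # full columns, across all channels
        marked |= {c * height * width + h * width + w
                   for c in range(num_channel) for h in range(height)}
    return sorted(marked)
```

For a 1-channel 2x2 image, marginalizing height index 0 marks flat indices 0 and 1 (the top row).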
- sample(num_samples=None, data=None, is_mpe=False, cache=None, return_leaf_params=False)[source]¶
Samples from the wrapped module, returning results in image format.
- Parameters:
- num_samples (int | None) – Number of samples to generate.
- data (Tensor | None) – Four-dimensional PyTorch tensor containing the input data. Shape: (batch_size, num_channel, height, width).
- is_mpe (bool) – Boolean value indicating whether to perform maximum a posteriori estimation (MPE). Defaults to False.
- cache (Cache | None) – Optional cache dictionary for memoization.
- return_leaf_params (bool) – Whether to return leaf distribution parameters collected during traversal.
- Return type:
Tensor
- Returns:
Four-dimensional PyTorch tensor in image format containing the sampled values. Shape: (batch_size, num_channel, height, width).
- to_image_format(tensor, batch=True)[source]¶
Convert 2D tensor to 4D image format.
- Parameters:
- tensor (Tensor) – 2D tensor to convert.
- batch (bool) – Whether the first dimension of the tensor is a batch dimension. Defaults to True.
- Returns:
4D tensor in image format.
- Raises:
ShapeError – If tensor dimensions are incompatible.
JointLogLikelihood¶
Wrapper that exposes the joint log-likelihood as a single feature, i.e. reduces a
(batch, features, channels, repetitions) tensor to (batch, 1, channels, repetitions)
by summing over the feature axis.
This is a convenience adapter to make some multivariate leaves (e.g. CLTree) behave like
typical “root” modules in SPFlow. It is a tensor reduction and does not introduce any
additional independence/factorization semantics.
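The reduction described above is a plain sum over the feature axis. A pure-Python stand-in for the tensor operation (in PyTorch terms this would be a sum over dim=1 with keepdim=True):

```python
def joint_ll(ll):
    """Reduce per-feature log-likelihoods of shape
    (batch, features, channels, repetitions) to
    (batch, 1, channels, repetitions) by summing over the feature axis.
    Pure-Python stand-in for the tensor reduction; `ll` is nested lists.
    """
    batch = []
    for sample in ll:  # sample: (features, channels, repetitions)
        n_ch = len(sample[0])
        n_rep = len(sample[0][0])
        # Sum log-likelihoods across features for each (channel, repetition).
        summed = [[sum(feat[c][r] for feat in sample) for r in range(n_rep)]
                  for c in range(n_ch)]
        batch.append([summed])  # keep a singleton feature axis
    return batch
```

Summing log-likelihoods is multiplying likelihoods, which is why this exposes the wrapped module's joint log-likelihood as one feature without adding independence semantics.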
- class spflow.zoo.cms.JointLogLikelihood(module)[source]¶
Bases: Wrapper
Expose a wrapped module’s joint log-likelihood as a single feature.
- log_likelihood(data, cache=None)[source]¶
Compute log likelihood P(data | module).
Computes log probability of input data under this module’s distribution. Uses log-space for numerical stability. Results should be cached for efficiency.
- Parameters:
- data (Tensor) – Input data of shape (batch_size, num_features). NaN values indicate missing values to marginalize over.
- cache (Cache | None, optional) – Cache for intermediate computations. Defaults to None.
- Returns:
Log-likelihood of shape (batch_size, out_features, out_channels).
- Return type:
Tensor
- Raises:
ValueError – If input data shape is incompatible with module scope.
- marginalize(marg_rvs, prune=True, cache=None)[source]¶
Structurally marginalize out specified random variables from the module.
Computes a new module representing the marginal distribution by integrating out the specified variables from the structure. For data-level marginalization, use NaNs in log_likelihood inputs.
- Parameters:
- marg_rvs (list[int]) – Random variable indices to marginalize out.
- prune (bool, optional) – Whether to prune unnecessary modules during marginalization. Defaults to True.
- cache (Cache | None, optional) – Cache for intermediate computations. Defaults to None.
- Returns:
Marginalized module, or None if all variables are marginalized out.
- Return type:
Module | None
- Raises:
ValueError – If marginalization variables are not in the module’s scope.
- sample(num_samples=None, data=None, is_mpe=False, cache=None, return_leaf_params=False)[source]¶
Generate samples from the module’s probability distribution.
Supports both random sampling and MAP inference (via is_mpe flag). Handles conditional sampling through evidence in data tensor.
- Parameters:
- num_samples (int | None, optional) – Number of samples to generate. Defaults to 1.
- data (Tensor | None, optional) – Pre-allocated tensor with NaN values indicating where to sample. If None, creates new tensor. Defaults to None.
- is_mpe (bool, optional) – If True, returns most probable values instead of random samples. Defaults to False.
- cache (Cache | None, optional) – Cache for intermediate computations. Defaults to None.
- return_leaf_params (bool, optional) – If True, also return leaf distribution parameters gathered during traversal.
- Return type:
- Returns:
Sampled values of shape (batch_size, num_features), optionally with collected leaf-parameter records.
- Raises:
ValueError – If sampling parameters are incompatible.
MarginalizationContext¶
Context for spatial marginalization in image data.
- class spflow.modules.wrapper.image_wrapper.MarginalizationContext(c=None, h=None, w=None)[source]¶
Bases: object
Context for spatial marginalization in image data.
Provides a structured way to specify which spatial dimensions (channels, height, width) to marginalize when working with image data in probabilistic circuits.
- c¶
Channel indices to marginalize.
- Type:
list[int]
- h¶
Height indices to marginalize.
- Type:
list[int]
- w¶
Width indices to marginalize.
- Type:
list[int]
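The shape of this context can be sketched as a small dataclass. The None-to-empty-list normalization is an assumed convention (matching the `c=None, h=None, w=None` defaults in the signature above), not necessarily the library's exact behavior.

```python
from dataclasses import dataclass


@dataclass
class MarginalizationContext:
    """Sketch of the context object: each field lists indices to
    marginalize; None is normalized to an empty list (assumed convention)."""
    c: list = None  # channel indices to marginalize
    h: list = None  # height indices to marginalize
    w: list = None  # width indices to marginalize

    def __post_init__(self):
        self.c = [] if self.c is None else self.c
        self.h = [] if self.h is None else self.h
        self.w = [] if self.w is None else self.w
```

For example, `MarginalizationContext(h=[0, 1])` marginalizes the top two rows across all channels while leaving channel and width untouched.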