Wrapper Modules¶
Wrapper classes that adapt SPFlow modules for specific data formats and use cases.
The wrapper pattern enables flexible extension of SPFlow capabilities while maintaining compatibility with the core module interfaces.
Wrapper (Base Class)¶
Abstract base class for SPFlow module wrappers.
- class spflow.modules.wrapper.base.Wrapper(module)[source]¶
Abstract base class for SPFlow module wrappers.
Provides a foundation for creating wrapper modules that adapt existing SPFlow modules for specific use cases, data formats, or integration scenarios. Wrapper modules delegate most operations to the wrapped module while providing specialized functionality for specific contexts. All abstract methods from Module are delegated to the wrapped module, and concrete implementations override only the methods they need to change. Wrapper modules inherit their scope and structure from the wrapped module.
The wrapper pattern enables:
- Custom data format handling (images, sequences, etc.)
- Preprocessing and postprocessing integration
- External framework compatibility layers
- Specialized input/output transformations
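The delegation idea described above can be sketched in plain Python. This is a minimal illustration, not SPFlow's implementation: the `Module`, `ToyLeaf`, `Wrapper`, and `ScalingWrapper` classes below are hypothetical stand-ins defined only for this example, and the real interface has more methods than the single `log_likelihood` shown here.

```python
from abc import ABC, abstractmethod


class Module(ABC):
    """Hypothetical stand-in for a module interface (not the real SPFlow class)."""

    def __init__(self, scope):
        self.scope = scope

    @abstractmethod
    def log_likelihood(self, data):
        ...


class ToyLeaf(Module):
    """Toy module that scores each row by the negative sum of its entries."""

    def log_likelihood(self, data):
        return [-float(sum(row)) for row in data]


class Wrapper(Module):
    """Delegates to the wrapped module and takes over its scope."""

    def __init__(self, module):
        super().__init__(module.scope)
        self.module = module

    def log_likelihood(self, data):
        # Default behaviour: forward the call unchanged.
        return self.module.log_likelihood(data)


class ScalingWrapper(Wrapper):
    """Hypothetical subclass that preprocesses inputs before delegating."""

    def __init__(self, module, scale):
        super().__init__(module)
        self.scale = scale

    def log_likelihood(self, data):
        scaled = [[x * self.scale for x in row] for row in data]
        return super().log_likelihood(scaled)


leaf = ToyLeaf(scope=(0, 1))
wrapped = ScalingWrapper(leaf, scale=0.5)
result = wrapped.log_likelihood([[2.0, 4.0]])
```

Only the overridden method changes behaviour; everything else, including the scope, comes from the wrapped module.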
- module¶
The wrapped SPFlow module.
- Type:
Module
- scope¶
Variable scope inherited from wrapped module.
- Type:
Scope
- __init__(module)[source]¶
Initialize wrapper with specified SPFlow module.
Creates a wrapper that delegates most operations to the wrapped module while allowing specialized overrides for specific functionality. The wrapped module’s scope and structure are preserved, and all abstract Module interface methods are delegated by default. Subclasses override specific methods to add custom wrapper functionality.
- Parameters:
module (Module) – The SPFlow module to wrap. Can be any valid SPFlow module, including complex circuit structures.
- property device¶
Returns the device of the wrapped module.
Automatically determines the device where the wrapped module’s parameters are located, handling multi-device scenarios gracefully.
- Returns:
Device where the wrapped module parameters are located.
- Return type:
ImageWrapper¶
Adapts SPFlow modules for 4D image data (batch, channels, height, width). Provides automatic conversion between flattened tensors and image format.
- class spflow.modules.wrapper.image_wrapper.ImageWrapper(module, num_channel, height, width)[source]¶
Bases: Wrapper
Wrapper for adapting SPFlow modules to image data format.
Provides automatic conversion between the 2D flattened tensors used by SPFlow modules and the 4D image tensors (batch, channels, height, width) commonly used in computer vision applications. The wrapper validates image dimensions against the module scope, converts between the 2D and 4D tensor formats, and supports all standard SPFlow operations on image data. This keeps the spatial structure of the image data while still allowing standard SPFlow modules to be used.
- module¶
Wrapped SPFlow module.
- Type:
Module
- __init__(module, num_channel, height, width)[source]¶
Initialize image wrapper.
Creates a wrapper that adapts SPFlow modules for image data with automatic format conversion and validation.
- Parameters:
- Raises:
StructureError – If module scope size doesn’t match image dimensions.
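The dimension check behind this error can be sketched as follows, assuming that "matching" means the scope size equals `num_channel * height * width` (consistent with the flattened shape documented for `flatten`). The helper names are hypothetical, and `ValueError` stands in for SPFlow's `StructureError`, which is not imported here.

```python
def expected_scope_size(num_channel, height, width):
    """Number of random variables needed to cover every pixel of every channel."""
    return num_channel * height * width


def check_image_scope(scope_size, num_channel, height, width):
    """Raise if the module scope cannot represent an image of the given shape."""
    expected = expected_scope_size(num_channel, height, width)
    if scope_size != expected:
        # ValueError is a stand-in; the documented error type is StructureError.
        raise ValueError(
            f"scope has {scope_size} variables, but a "
            f"{num_channel}x{height}x{width} image needs {expected}"
        )


# A 28x28 single-channel image needs a module over 784 variables.
check_image_scope(784, num_channel=1, height=28, width=28)
```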
- expectation_maximization(data, cache=None)[source]¶
Performs a single expectation maximization (EM) step for the wrapped module.
- flatten(tensor)[source]¶
Convert 4D image tensor to 2D flattened tensor.
- Parameters:
tensor (Tensor) – 4D tensor of shape (batch, channels, height, width).
- Returns:
2D flattened tensor of shape (batch, channels*height*width).
- Raises:
ShapeError – If the tensor is not 4D or its channel dimension does not match.
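To make the shape change concrete, here is a small sketch that uses nested Python lists in place of tensors. It assumes row-major ordering (channel, then row, then column), which is what a plain reshape of a contiguous tensor produces; whether `flatten` uses exactly this ordering is an assumption, and both helper functions are hypothetical.

```python
def flatten_images(images):
    """Flatten nested lists shaped (batch, channels, height, width) to (batch, C*H*W)."""
    return [
        [pixel for channel in image for row in channel for pixel in row]
        for image in images
    ]


def flat_index(c, h, w, height, width):
    """Position of pixel (c, h, w) in the flattened row, assuming row-major order."""
    return (c * height + h) * width + w


# One image with 2 channels of size 2x3; pixel values equal their flat position.
images = [[[[0, 1, 2], [3, 4, 5]], [[6, 7, 8], [9, 10, 11]]]]
flat = flatten_images(images)
```

With this ordering, each flattened column index corresponds to exactly one (channel, row, column) position.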
- log_likelihood(data, cache=None)[source]¶
Computes log-likelihoods for the wrapped module given the data.
Missing values (i.e., NaN) are marginalized over.
- Parameters:
- Return type:
- Returns:
Two-dimensional PyTorch tensor containing the log-likelihoods of the input data. Each row corresponds to an input sample.
- Raises:
ValueError – Data outside of support.
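What "marginalized over" means for NaN entries can be illustrated with a toy model. The sketch below assumes a fully factorized Bernoulli distribution, where marginalizing a variable simply drops its factor (it sums to 1, so it contributes log 1 = 0). The function is hypothetical and much simpler than a general circuit, where marginalization is carried out at the leaves.

```python
import math


def factorized_log_likelihood(row, probs):
    """Log-likelihood of one sample under independent Bernoulli variables.

    probs[i] is P(X_i = 1). NaN entries are treated as missing and skipped.
    """
    total = 0.0
    for value, p in zip(row, probs):
        if math.isnan(value):
            continue  # missing value: the factor marginalizes to 1
        total += math.log(p if value == 1 else 1.0 - p)
    return total


probs = [0.5, 0.9]
ll_full = factorized_log_likelihood([1.0, 0.0], probs)
ll_missing = factorized_log_likelihood([1.0, float("nan")], probs)
```

Here `ll_missing` depends only on the observed first variable, while `ll_full` also includes the second variable's term.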
- marginalize(marg_ctx, prune=True, cache=None)[source]¶
Marginalize out spatial dimensions from the wrapped module.
- Parameters:
marg_ctx (MarginalizationContext) – Context specifying which dimensions to marginalize.
prune (bool) – Whether to prune the structure after marginalization.
cache (Cache | None) – Optional cache dictionary for memoization.
- Return type:
- Returns:
New ImageWrapper with marginalized module and adjusted dimensions, or None if module is fully marginalized.
- maximum_likelihood_estimation(data, weights=None, cache=None)[source]¶
Update parameters via maximum likelihood estimation for the wrapped module.
- sample(num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None)[source]¶
Samples from the wrapped module, returning results in image format.
- Parameters:
data (Tensor | None) – Four-dimensional PyTorch tensor containing the input data. Shape: (batch_size, num_channel, height, width).
is_mpe (bool) – Boolean value indicating whether to perform most probable explanation (MPE) inference. Defaults to False.
cache (Cache | None) – Optional cache dictionary for memoization.
sampling_ctx (Optional[SamplingContext]) – Optional sampling context containing the instances (i.e., rows) of data to fill with sampled values and the output indices of the node to sample from.
- Return type:
- Returns:
Four-dimensional PyTorch tensor in image format containing the sampled values. Shape: (batch_size, num_channel, height, width).
- to_image_format(tensor, batch=True)[source]¶
Convert 2D tensor to 4D image format.
- Parameters:
- Returns:
4D tensor in image format.
- Raises:
ShapeError – If tensor dimensions are incompatible.
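The reverse conversion can be sketched with nested lists as well. This assumes the same row-major ordering as a plain reshape and does not model the `batch` argument, whose exact behaviour is not described here. `unflatten_rows` is a hypothetical helper, and `ValueError` stands in for `ShapeError`.

```python
def unflatten_rows(flat_rows, num_channel, height, width):
    """Convert rows of length C*H*W into nested lists shaped (batch, C, H, W)."""
    expected = num_channel * height * width
    images = []
    for row in flat_rows:
        if len(row) != expected:
            # Stand-in for the documented ShapeError.
            raise ValueError(f"row has {len(row)} values, expected {expected}")
        image = [
            [
                [row[(c * height + h) * width + w] for w in range(width)]
                for h in range(height)
            ]
            for c in range(num_channel)
        ]
        images.append(image)
    return images


restored = unflatten_rows([list(range(12))], num_channel=2, height=2, width=3)
```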
MarginalizationContext¶
Context for spatial marginalization in image data.
- class spflow.modules.wrapper.image_wrapper.MarginalizationContext(c=None, h=None, w=None)[source]¶
Bases: object
Context for spatial marginalization in image data.
Provides a structured way to specify which spatial dimensions (channels, height, width) to marginalize when working with image data in probabilistic circuits.
- c¶
Channel indices to marginalize.
- Type:
list[int]
- h¶
Height indices to marginalize.
- Type:
list[int]
- w¶
Width indices to marginalize.
- Type:
list[int]
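One plausible reading of these index lists is sketched below. It assumes that listing an index removes the whole channel, row, or column it names, which would be consistent with marginalize returning a wrapper with adjusted dimensions. That interpretation is an assumption, as is the row-major flat ordering. `ToyMarginalizationContext` and both helper functions are hypothetical and are not part of SPFlow.

```python
from dataclasses import dataclass, field


@dataclass
class ToyMarginalizationContext:
    """Hypothetical stand-in holding channel, height, and width indices."""

    c: list = field(default_factory=list)
    h: list = field(default_factory=list)
    w: list = field(default_factory=list)


def marginalized_flat_indices(ctx, num_channel, height, width):
    """Flat indices of all pixels in a listed channel, row, or column."""
    dropped = []
    for c in range(num_channel):
        for h in range(height):
            for w in range(width):
                if c in ctx.c or h in ctx.h or w in ctx.w:
                    dropped.append((c * height + h) * width + w)
    return dropped


def remaining_shape(ctx, num_channel, height, width):
    """Image shape left after removing the listed channels, rows, and columns."""
    return (
        num_channel - len(set(ctx.c)),
        height - len(set(ctx.h)),
        width - len(set(ctx.w)),
    )


# Marginalize the first row of a 1x2x3 image.
ctx = ToyMarginalizationContext(h=[0])
dropped = marginalized_flat_indices(ctx, num_channel=1, height=2, width=3)
shape = remaining_shape(ctx, num_channel=1, height=2, width=3)
```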