Wrapper Modules

Wrapper classes that adapt SPFlow modules for specific data formats and use cases.

The wrapper pattern enables flexible extension of SPFlow capabilities while maintaining compatibility with the core module interfaces.

Wrapper (Base Class)

Abstract base class for SPFlow module wrappers.

class spflow.modules.wrapper.base.Wrapper(module)[source]

Bases: Module, ABC

Abstract base class for SPFlow module wrappers.

Provides a foundation for creating wrapper modules that adapt existing SPFlow modules for specific use cases, data formats, or integration scenarios. Wrapper modules delegate most operations to the wrapped module while providing specialized functionality for specific contexts. All abstract methods from Module are delegated to the wrapped module by default, concrete implementations should override specific methods as needed, and wrapper modules inherit their scope and structure from the wrapped module.

The wrapper pattern enables:
  • Custom data format handling (images, sequences, etc.)

  • Preprocessing and postprocessing integration

  • External framework compatibility layers

  • Specialized input/output transformations
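As an illustration, a preprocessing wrapper could be sketched as follows. This is not part of SPFlow: the StandardizingWrapper class, its mean/std attributes, and the assumption that the wrapped module's log_likelihood takes (data, cache) are all hypothetical; only the delegation pattern mirrors the Wrapper base class described here.

    from spflow.modules.wrapper.base import Wrapper


    class StandardizingWrapper(Wrapper):
        """Hypothetical wrapper that standardizes inputs before delegating."""

        def __init__(self, module, mean, std):
            super().__init__(module)  # inherits scope and structure from `module`
            self.mean = mean
            self.std = std

        def log_likelihood(self, data, cache=None):
            # Preprocess the data, then delegate to the wrapped module.
            return self.module.log_likelihood((data - self.mean) / self.std, cache=cache)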

module

The wrapped SPFlow module.

Type:

Module

scope

Variable scope inherited from wrapped module.

Type:

Scope

__init__(module)[source]

Initialize wrapper with specified SPFlow module.

Creates a wrapper that delegates most operations to the wrapped module while allowing specialized overrides for specific functionality. The wrapped module’s scope and structure are preserved and all abstract Module interface methods are delegated by default; override specific methods to add custom wrapper functionality.

Parameters:

module (Module) – The SPFlow module to wrap. Can be any valid SPFlow module including complex circuit structures.

property device

Returns the device of the wrapped module.

Automatically determines the device where the wrapped module’s parameters are located, handling multi-device scenarios gracefully.

Returns:

Device where the wrapped module parameters are located.

Return type:

torch.device

property feature_to_scope: ndarray

Returns the mapping from features to scopes from the wrapped module.

Delegates to the wrapped module’s feature_to_scope property.

Returns:

Feature-to-scope mapping from the wrapped module.

Return type:

list[Scope]

ImageWrapper

Adapts SPFlow modules for 4D image data (batch, channels, height, width). Provides automatic conversion between flattened tensors and image format.

class spflow.modules.wrapper.image_wrapper.ImageWrapper(module, num_channel, height, width)[source]

Bases: Wrapper

Wrapper for adapting SPFlow modules to image data format.

Provides automatic conversion between the 2D flattened tensors used by SPFlow modules and the 4D image tensors (batch, channels, height, width) common in computer vision. The wrapper validates the image dimensions against the module scope, converts between 2D and 4D tensor formats, and supports all standard SPFlow operations on image data, preserving the spatial structure of images while enabling the use of standard SPFlow modules.

module

Wrapped SPFlow module.

Type:

Module

num_channel

Number of image channels.

Type:

int

height

Image height in pixels.

Type:

int

width

Image width in pixels.

Type:

int

__init__(module, num_channel, height, width)[source]

Initialize image wrapper.

Creates a wrapper that adapts SPFlow modules for image data with automatic format conversion and validation.

Parameters:
  • module (Module) – SPFlow module to wrap.

  • num_channel (int) – Number of channels in image.

  • height (int) – Height of image in pixels.

  • width (int) – Width of image in pixels.

Raises:

StructureError – If module scope size doesn’t match image dimensions.
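Illustrative construction (a sketch; `leaf_circuit` is a placeholder, not an SPFlow name, standing in for any existing SPFlow module whose scope covers num_channel * height * width variables):

    from spflow.modules.wrapper.image_wrapper import ImageWrapper

    # `leaf_circuit` is a placeholder module whose scope size must equal
    # num_channel * height * width (here 1 * 28 * 28 = 784);
    # a mismatch raises StructureError.
    wrapper = ImageWrapper(leaf_circuit, num_channel=1, height=28, width=28)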

expectation_maximization(data, cache=None)[source]

Performs a single expectation maximization (EM) step for the wrapped module.

Parameters:
  • data (Tensor) – Four-dimensional PyTorch tensor containing the input data. Shape: (batch_size, num_channel, height, width).

  • cache (Cache | None) – Optional cache dictionary for memoization.

Return type:

None
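Illustrative EM step (a sketch reusing the `wrapper` from the construction example above; the random tensor is a placeholder for real image data):

    import torch

    train_images = torch.rand(1000, 1, 28, 28)      # placeholder 4D image batch
    wrapper.expectation_maximization(train_images)  # one EM step; parameters are updated in place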

flatten(tensor)[source]

Convert 4D image tensor to 2D flattened tensor.

Parameters:

tensor (Tensor) – 4D tensor of shape (batch, channels, height, width).

Returns:

2D flattened tensor of shape (batch, channels*height*width).

Raises:

ShapeError – If the tensor is not 4D or the channel dimension does not match.
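The conversion is conceptually a reshape from (batch, channels, height, width) to (batch, channels*height*width). An illustrative call, continuing the `wrapper` sketch above:

    import torch

    images = torch.randn(32, 1, 28, 28)   # (batch, channels, height, width)
    flat = wrapper.flatten(images)        # (batch, channels * height * width)
    assert flat.shape == (32, 784)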

log_likelihood(data, cache=None)[source]

Computes log-likelihoods for the wrapped module given the data.

Missing values (i.e., NaN) are marginalized over.

Parameters:
  • data (Tensor) – Four-dimensional PyTorch tensor containing the input data. Shape: (batch_size, num_channel, height, width).

  • cache (Cache | None) – Optional cache dictionary for memoization.

Return type:

Tensor

Returns:

Two-dimensional PyTorch tensor containing the log-likelihoods of the input data. Each row corresponds to an input sample.

Raises:

ValueError – Data outside of support.
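Illustrative usage (continuing the `wrapper` sketch above; NaN entries mark values to marginalize over, and the random data is a placeholder):

    import torch

    batch = torch.randn(32, 1, 28, 28)
    batch[:, :, 0, :] = float("nan")            # marginalize the first pixel row of each image
    log_probs = wrapper.log_likelihood(batch)   # 2D tensor, one row per input sample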

marginalize(marg_ctx, prune=True, cache=None)[source]

Marginalize out spatial dimensions from the wrapped module.

Parameters:
  • marg_ctx (MarginalizationContext) – MarginalizationContext specifying which dimensions to marginalize.

  • prune (bool) – Whether to prune the structure after marginalization.

  • cache (Cache | None) – Optional cache dictionary for memoization.

Return type:

Optional[ImageWrapper]

Returns:

New ImageWrapper with marginalized module and adjusted dimensions, or None if the module is fully marginalized.
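Illustrative usage (a sketch; the index choices are arbitrary):

    from spflow.modules.wrapper.image_wrapper import MarginalizationContext

    marg_ctx = MarginalizationContext(h=[0, 1])   # marginalize the top two pixel rows
    smaller = wrapper.marginalize(marg_ctx)       # new ImageWrapper with adjusted height, or None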

maximum_likelihood_estimation(data, weights=None, cache=None)[source]

Update parameters via maximum likelihood estimation for the wrapped module.

Parameters:
  • data (Tensor) – Four-dimensional PyTorch tensor containing the input data. Shape: (batch_size, num_channel, height, width).

  • weights (Optional[Tensor]) – Optional sample weights tensor.

  • cache (Cache | None) – Optional cache dictionary for memoization.

Return type:

None
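Illustrative parameter fitting (a sketch reusing `train_images` from the EM example above; the weight shape is an assumption, not documented here):

    import torch

    sample_weights = torch.ones(1000)   # optional per-sample weights (uniform here)
    wrapper.maximum_likelihood_estimation(train_images, weights=sample_weights)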

sample(num_samples=None, data=None, is_mpe=False, cache=None, sampling_ctx=None)[source]

Samples from the wrapped module, returning results in image format.

Parameters:
  • num_samples (int | None) – Number of samples to generate.

  • data (Tensor | None) – Four-dimensional PyTorch tensor containing the input data. Shape: (batch_size, num_channel, height, width).

  • is_mpe (bool) – Boolean value indicating whether to perform maximum a posteriori estimation (MPE). Defaults to False.

  • cache (Cache | None) – Optional cache dictionary for memoization.

  • sampling_ctx (Optional[SamplingContext]) – Optional sampling context containing the instances (i.e., rows) of data to fill with sampled values and the output indices of the node to sample from.

Return type:

Tensor

Returns:

Four-dimensional PyTorch tensor in image format containing the sampled values. Shape: (batch_size, num_channel, height, width).
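Illustrative sampling (continuing the `wrapper` sketch above):

    samples = wrapper.sample(num_samples=16)           # (16, num_channel, height, width)
    mpe = wrapper.sample(num_samples=16, is_mpe=True)  # MPE values instead of random draws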

to_image_format(tensor, batch=True)[source]

Convert 2D tensor to 4D image format.

Parameters:
  • tensor (Tensor) – 2D tensor to reshape.

  • batch (bool) – Whether to include batch dimension.

Returns:

4D tensor in image format.

Raises:

ShapeError – If tensor dimensions are incompatible.
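Illustrative round trip with flatten (continuing the sketch above):

    flat = wrapper.flatten(images)            # (batch, channels * height * width)
    restored = wrapper.to_image_format(flat)  # back to (batch, channels, height, width)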

MarginalizationContext

Context for spatial marginalization in image data.

class spflow.modules.wrapper.image_wrapper.MarginalizationContext(c=None, h=None, w=None)[source]

Bases: object

Context for spatial marginalization in image data.

Provides a structured way to specify which spatial dimensions (channels, height, width) to marginalize when working with image data in probabilistic circuits.

c

Channel indices to marginalize.

Type:

list[int]

h

Height indices to marginalize.

Type:

list[int]

w

Width indices to marginalize.

Type:

list[int]

__init__(c=None, h=None, w=None)[source]

Initialize marginalization context.

Parameters:
  • c (list[int]) – Channel indices to marginalize.

  • h (list[int]) – Height indices to marginalize.

  • w (list[int]) – Width indices to marginalize.
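Illustrative contexts (the index choices are arbitrary):

    from spflow.modules.wrapper.image_wrapper import MarginalizationContext

    MarginalizationContext(c=[0])             # marginalize the first channel
    MarginalizationContext(h=[0, 1], w=[2])   # marginalize two rows and one column
    MarginalizationContext()                  # marginalize nothing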