Source code for spflow.modules.wrapper.base

"""Base wrapper classes for SPFlow module adaptation.

This module provides abstract base classes and utilities for creating wrapper
modules that adapt existing SPFlow modules for specific use cases or data formats.
Wrapper modules extend SPFlow functionality without modifying core module
implementations, while maintaining compatibility with the core module
interfaces.
"""

from abc import ABC

import numpy as np
import torch

from spflow.meta.data import Scope
from spflow.modules.module import Module


class Wrapper(Module, ABC):
    """Abstract base class for SPFlow module wrappers.

    A wrapper holds another SPFlow module and adapts it for a specific use
    case, data format, or integration scenario without modifying the wrapped
    module itself.

    This base class provides only a small amount of delegation:

    - ``scope``, ``in_shape`` and ``out_shape`` are copied from the wrapped
      module once, at construction time.
    - ``feature_to_scope`` and ``device`` are read from the wrapped module on
      every access.

    Any other ``Module`` behavior is not forwarded here. Concrete subclasses
    (or ``Module`` itself) must provide it, typically by calling into
    ``self.module``.

    Typical uses of the wrapper pattern:
        - Custom data format handling (images, sequences, etc.)
        - Preprocessing and postprocessing integration
        - External framework compatibility layers
        - Specialized input/output transformations

    Attributes:
        module (Module): The wrapped SPFlow module.
        scope (Scope): Variable scope copied from the wrapped module.
        in_shape: Input shape copied from the wrapped module at construction.
        out_shape: Output shape copied from the wrapped module at construction.
    """

    def __init__(self, module: Module):
        """Initialize the wrapper around ``module``.

        Args:
            module (Module): The SPFlow module to wrap. Can be any valid SPFlow
                module, including complex circuit structures.
        """
        super().__init__()
        self.module = module
        self.scope = module.scope

        # Shapes are copied (not looked up lazily): if the wrapped module's
        # shapes change after construction, these attributes will be stale.
        self.in_shape = self.module.in_shape
        self.out_shape = self.module.out_shape

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Return the wrapped module's feature-to-scope mapping.

        The value is returned as-is from ``self.module.feature_to_scope``
        (no copy is made).

        Returns:
            np.ndarray: Feature-to-scope mapping of the wrapped module.
        """
        return self.module.feature_to_scope

    @property
    def device(self) -> torch.device:
        """Return the device the wrapped module's tensors live on.

        The device is taken from the wrapped module's first parameter. If the
        module has no parameters, the first buffer is used instead. If it has
        neither, ``torch.device("cpu")`` is returned.

        Only the first tensor found is inspected, so for a module spread
        across several devices this reports just one of them.

        Returns:
            torch.device: Device of the wrapped module's first parameter/buffer,
            or CPU if the module holds no tensors.
        """
        # next(..., None) avoids leaking a bare StopIteration from a property
        # (which turns into a confusing RuntimeError inside generators).
        tensor = next(iter(self.module.parameters()), None)
        if tensor is None:
            tensor = next(iter(self.module.buffers()), None)
        if tensor is None:
            # No tensor state at all: fall back to PyTorch's default CPU device.
            return torch.device("cpu")
        return tensor.device

    def extra_repr(self) -> str:
        """Return a compact summary of the wrapper's output shape.

        Shows the output features (D), output channels (C), and number of
        repetitions (R) taken from ``self.out_shape``.

        Returns:
            str: String of the form "D={features}, C={channels}, R={repetitions}".
        """
        return f"D={self.out_shape.features}, C={self.out_shape.channels}, R={self.out_shape.repetitions}"