Source code for spflow.modules.module_shape
"""Shape representation for SPFlow modules.
Provides the ModuleShape dataclass for representing tensor shapes
(excluding batch dimension) flowing through probabilistic circuit modules.
"""
from dataclasses import dataclass
[docs]
@dataclass(frozen=True)
class ModuleShape:
    """Represents tensor shape (excluding batch dimension).

    Shapes in SPFlow modules are 3-dimensional: (features, channels, repetitions).
    This dataclass provides named access and behaves like a read-only
    3-tuple: it supports iteration, indexing (negative indices and slices
    work exactly as on a tuple) and ``len()``.

    Attributes:
        features: Number of features (random variables/scope size).
        channels: Number of parallel channels/distributions.
        repetitions: Number of independent repetitions.

    Examples:
        >>> shape = ModuleShape(4, 8, 2)
        >>> shape.features
        4
        >>> shape[1]  # channels
        8
        >>> tuple(shape)
        (4, 8, 2)
        >>> len(shape)
        3
    """

    features: int
    channels: int
    repetitions: int

    def _as_tuple(self) -> tuple:
        """Return the dimensions as a ``(features, channels, repetitions)`` tuple."""
        return (self.features, self.channels, self.repetitions)

    def __iter__(self):
        """Iterate over shape dimensions in (features, channels, repetitions) order."""
        return iter(self._as_tuple())

    def __len__(self) -> int:
        """Return the number of shape dimensions (always 3)."""
        # Completes the sequence-like protocol: iteration and indexing already
        # work, so len(shape) should not raise TypeError.
        return 3

    def __getitem__(self, idx: int) -> int:
        """Index into shape dimensions.

        Indexing is delegated to a plain tuple, so negative indices and
        slices behave as they do for ``tuple`` (a slice returns a tuple).

        Raises:
            IndexError: If ``idx`` is outside ``range(-3, 3)``.
        """
        return self._as_tuple()[idx]

    def __repr__(self) -> str:
        """Return concise string representation."""
        return f"Shape(F={self.features}, C={self.channels}, R={self.repetitions})"