from __future__ import annotations
import numpy as np
import torch
from torch import Tensor
from spflow.exceptions import ShapeError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.modules.ops.split_interleaved import SplitInterleaved
from spflow.modules.products.base_product import BaseProduct
from spflow.utils.cache import Cache, cached
class ElementwiseProduct(BaseProduct):
"""Elementwise product with automatic broadcasting.
Multiplies its inputs element-wise, broadcasting singleton feature or channel
dimensions as needed. All input scopes must be pairwise disjoint. Commonly used
in RAT-SPN architectures.
"""
def __init__(
self,
inputs: Module | tuple[Module, Module] | list[Module],
num_splits: int | None = 2,
split_mode: SplitMode | None = None,
) -> None:
"""Initialize elementwise product.
Args:
inputs:
A single Module, a pair of Modules, or a list of Modules.
The scopes of all child modules must be pairwise disjoint.
If `inputs` is a list of Modules, they must have an equal number of features (or a single feature, which will then be broadcast) and an equal number of channels (or a single channel, which will be broadcast).
Example shapes:
inputs = ((3, 4), (3, 4))
output = (3, 4)
inputs = ((3, 4), (3, 1))
output = (3, 4) # broadcasted
inputs = ((3, 4), (1, 4))
output = (3, 4) # broadcasted
inputs = ((3, 1), (1, 4))
output = (3, 4) # broadcasted
num_splits: Number of splits to use when a single input is wrapped in a Split module.
split_mode: Optional split configuration for single input mode.
Use SplitMode.consecutive() or SplitMode.interleaved().
Defaults to SplitMode.consecutive(num_splits=num_splits) if not specified.
Raises:
ValueError: If `num_splits` does not match across split inputs.
ShapeError: If the input feature or channel counts are not broadcastable.
"""
# Handle single non-Split input: wrap with split_mode
if isinstance(inputs, Module) and not isinstance(inputs, (list, tuple, Split)):
if split_mode is not None:
inputs = split_mode.create(inputs)
else:
inputs = SplitConsecutive(inputs, num_splits=num_splits)
super().__init__(inputs=inputs)
# Check if all inputs either have equal number of out_channels or 1 (using in_shape.channels set by BaseProduct)
if not all(inp.out_shape.channels in (1, self.in_shape.channels) for inp in self.inputs):
raise ShapeError(
f"Inputs must have an equal number of channels, or a channel count of 1, but got {[inp.out_shape.channels for inp in self.inputs]}"
)
if self.num_splits is None:
self.num_splits = num_splits
self.check_shapes()
# Shape computation: compute out_shape based on elementwise product
input_features = self.inputs[0].out_shape.features
if input_features == 1:
out_features = 1
elif self.input_is_split:
out_features = int(input_features // self.num_splits)
else:
out_features = input_features
# out_channels is max since one input can have single channel (broadcast)
out_channels = self.in_shape.channels
self.out_shape = ModuleShape(out_features, out_channels, self.in_shape.repetitions)
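# Worked example of the shape computation above: a split input with
# out_shape.features == 8 and num_splits == 2 pairs the two halves feature-wise,
# giving out_features == 4; the channel count is taken from in_shape.channels.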
def check_shapes(self):
"""Check if input shapes are compatible for broadcasting."""
inputs = self.inputs
if self.input_is_split:
if self.num_splits != inputs[0].num_splits:
raise ValueError("num_splits must be the same for all inputs")
out_f = self.inputs[0].out_shape.features
out_c = self.inputs[0].out_shape.channels
if out_f == 1:
out_f_computed = 1
else:
out_f_computed = int(out_f // self.num_splits)
shapes = inputs[0].get_out_shapes((out_f_computed, out_c))
else:
shapes = [(inp.out_shape.features, inp.out_shape.channels) for inp in inputs]
if not shapes:
return False # No shapes to check
# Extract dimensions
dim0_values = [shape[0] for shape in shapes]
dim1_values = [shape[1] for shape in shapes]
# Condition 1: All shapes are the same
if all(shape == shapes[0] for shape in shapes):
return True
# Condition 2: within each dimension, all sizes agree up to broadcasting,
# i.e. there is at most one distinct value other than 1
dim0_non_ones = set(dim0_values) - {1}
dim1_non_ones = set(dim1_values) - {1}
if len(dim0_non_ones) <= 1 and len(dim1_non_ones) <= 1:
return True
# If no condition is satisfied, the shapes cannot be broadcast together
raise ShapeError(f"The shapes of the inputs {shapes} are not broadcastable.")
@property
def feature_to_scope(self) -> np.ndarray:
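"""Array of joined Scopes per output feature, with shape (out_features, repetitions)."""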
out = []
for r in range(self.out_shape.repetitions):
if self.input_is_split:
scope_lists_r = self.inputs[0].feature_to_scope[
..., r
] # Shape: (num_features_per_split, num_splits)
scope_lists_r = [scope_lists_r[:, i] for i in range(self.num_splits)]
else:
scope_lists_r = [
module.feature_to_scope[..., r] for module in self.inputs
]  # One array of per-feature scopes for each input module
feature_to_scope_r = []
# Group elements by index
grouped_scopes_r = list(zip(*scope_lists_r))
for scopes_r in grouped_scopes_r:
feature_to_scope_r.append(Scope.join_all(scopes_r))
out.append(np.array(feature_to_scope_r))
return np.stack(out, axis=1)
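# Illustrative example (assuming SplitConsecutive halves the feature axis): with
# input features [0, 1, 2, 3] split into [0, 1] and [2, 3], output feature 0 covers
# the joined scope of input features 0 and 2, and output feature 1 covers 1 and 3.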
def map_out_channels_to_in_channels(self, output_ids: Tensor) -> Tensor:
"""Map output channel indices to input channel indices.
Args:
output_ids: Tensor of output channel indices to map.
Returns:
Tensor: Mapped input channel indices.
"""
if self.input_is_split:
num_splits = self.num_splits
if isinstance(self.inputs[0], SplitConsecutive):
return output_ids.repeat((1, num_splits)).unsqueeze(-1)
elif isinstance(self.inputs[0], SplitInterleaved):
return output_ids.repeat_interleave(num_splits, dim=1).unsqueeze(-1)
else:
raise NotImplementedError("Other Split types are not implemented yet.")
else:
num_splits = len(self.inputs)
return output_ids.unsqueeze(-1).repeat(1, 1, num_splits)
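# Shape sketch (illustrative, assuming output_ids of shape (batch, out_features)):
# with num_splits == 2, the SplitConsecutive branch tiles the ids along the feature
# axis and appends a singleton channel index, yielding (batch, 2 * out_features, 1);
# the multi-input branch instead yields (batch, out_features, len(self.inputs)).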
def map_out_mask_to_in_mask(self, mask: Tensor) -> Tensor:
"""Map output mask to input mask.
Args:
mask: Output mask tensor to map.
Returns:
Tensor: Mapped input mask tensor.
"""
if self.input_is_split:
num_splits = self.num_splits
if isinstance(self.inputs[0], SplitConsecutive):
return mask.repeat((1, num_splits)).unsqueeze(-1)
elif isinstance(self.inputs[0], SplitInterleaved):
return mask.repeat_interleave(num_splits, dim=1).unsqueeze(-1)
else:
raise NotImplementedError("Other Split types are not implemented yet.")
else:
num_splits = len(self.inputs)
return mask.unsqueeze(-1).repeat(1, 1, num_splits)
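# The mask mapping mirrors map_out_channels_to_in_channels: masks are tiled along
# the feature axis for split inputs and replicated once per input module otherwise.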
@cached
def log_likelihood(
self,
data: Tensor,
cache: Cache | None = None,
) -> Tensor:
"""Compute log likelihood by element-wise summing inputs.
Args:
data: Input data tensor.
cache: Optional cache for storing intermediate computations.
Returns:
Tensor: Computed log likelihood values.
"""
# Gather the log-likelihoods of all inputs (initializes the cache as needed)
lls = self._get_input_log_likelihoods(data, cache)
# Check if we need to expand to enable broadcasting along channels
for i, ll in enumerate(lls):
if ll.shape[2] == 1:
if ll.ndim == 4:
lls[i] = ll.expand(-1, -1, self.out_shape.channels, -1)
else:
lls[i] = ll.expand(-1, -1, self.out_shape.channels)
# Elementwise product in log space: sum the stacked input log-likelihoods
output = torch.sum(torch.stack(lls, dim=-1), dim=-1)
output = output.view(
output.size(0), self.out_shape.features, self.out_shape.channels, self.out_shape.repetitions
)
return output
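# Worked example (illustrative): in log space the elementwise product becomes a sum.
# For two inputs with log-likelihoods of shape (batch, 3, 4, 1) and (batch, 3, 1, 1),
# the second is expanded to 4 channels, the two are summed elementwise, and the
# result is reshaped to (batch, features=3, channels=4, repetitions=1).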