from __future__ import annotations
import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor
from spflow.exceptions import ShapeError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.modules.ops.split_interleaved import SplitInterleaved
from spflow.modules.products.base_product import BaseProduct
from spflow.utils.cache import Cache, cached
[docs]
class ElementwiseProduct(BaseProduct):
    """Elementwise product with automatic broadcasting.

    Multiplies inputs element-wise with broadcasting support. All input scopes
    must be pairwise disjoint. Commonly used in RAT-SPN architectures.
    """

    def __init__(
        self,
        inputs: Module | tuple[Module, Module] | list[Module],
        num_splits: int | None = 2,
        split_mode: SplitMode | None = None,
    ) -> None:
        """Initialize elementwise product.

        Args:
            inputs: A single Module or a list/tuple of Modules. The scopes of
                all child modules need to be pairwise disjoint. A single
                Module that is not already a ``Split`` is wrapped in one (see
                ``split_mode``). If several Modules are given, they must have
                an equal number of features or a single feature (which will
                then be broadcast), and an equal number of channels or a
                single channel (which will be broadcast).

                Example shapes, given as (features, channels):
                    inputs = ((3, 4), (3, 4))  ->  output = (3, 4)
                    inputs = ((3, 4), (3, 1))  ->  output = (3, 4)  # broadcasted
                    inputs = ((3, 4), (1, 4))  ->  output = (3, 4)  # broadcasted
                    inputs = ((3, 1), (1, 4))  ->  output = (3, 4)  # broadcasted
            num_splits: Number of splits when wrapping a single input in a
                Split. Also stored as ``self.num_splits`` when the base class
                left that attribute as ``None``.
            split_mode: Optional split configuration for single input mode.
                Use SplitMode.consecutive() or SplitMode.interleaved().
                Defaults to a consecutive split with ``num_splits`` parts if
                not specified.

        Raises:
            ShapeError: If the input channels are neither all equal nor 1, or
                if ``check_shapes`` rejects the input shapes.
            ValueError: If ``num_splits`` differs from the split input's
                ``num_splits``.
        """
        # Handle single non-Split input: wrap it so it can be multiplied part by part.
        if isinstance(inputs, Module) and not isinstance(inputs, Split):
            if split_mode is not None:
                inputs = split_mode.create(inputs)
            else:
                inputs = SplitConsecutive(inputs, num_splits=num_splits)

        super().__init__(inputs=inputs)

        # Every input must either match in_shape.channels (set by BaseProduct)
        # or have a single channel that can be broadcast.
        channels_per_input = [inp.out_shape.channels for inp in self.inputs]
        if any(channels not in (1, self.in_shape.channels) for channels in channels_per_input):
            raise ShapeError(
                f"Inputs must have equal number of channels or one of them must be '1', but were {channels_per_input}"
            )

        if self.num_splits is None:
            self.num_splits = num_splits

        self.check_shapes()

        # Shape computation: derive out_shape from the elementwise product.
        input_features = self.inputs[0].out_shape.features
        if input_features == 1:
            out_features = 1
        elif self.input_is_split:
            out_features = int(input_features // self.num_splits)
        else:
            out_features = input_features

        # in_shape.channels is used as the output channel count since an input
        # can have a single channel (broadcast).
        self.out_shape = ModuleShape(out_features, self.in_shape.channels, self.in_shape.repetitions)

    def check_shapes(self) -> bool:
        """Check if input shapes are compatible for broadcasting.

        Returns:
            bool: ``True`` when one of the accepted shape patterns matches,
            ``False`` when there are no shapes to check.

        Raises:
            ValueError: If the input is a split whose ``num_splits`` differs
                from ``self.num_splits``.
            ShapeError: If none of the accepted shape patterns matches.
        """
        inputs = self.inputs

        if self.input_is_split:
            split = inputs[0]
            if self.num_splits != split.num_splits:
                raise ValueError("num_splits must be the same for all inputs")
            features = split.out_shape.features
            channels = split.out_shape.channels
            features_per_split = 1 if features == 1 else int(features // self.num_splits)
            shapes = split.get_out_shapes((features_per_split, channels))
        else:
            shapes = [(inp.out_shape.features, inp.out_shape.channels) for inp in inputs]

        if not shapes:
            return False  # No shapes to check

        dim0_values = [shape[0] for shape in shapes]
        dim1_values = [shape[1] for shape in shapes]

        # Condition 1: all shapes are the same.
        if all(shape == shapes[0] for shape in shapes):
            return True

        # Conditions 2 and 3: ignoring 1s, one dimension has a single distinct value.
        # NOTE(review): each of these checks looks at one dimension only, so
        # e.g. feature counts 3 and 5 with equal channels are accepted here.
        # Behavior kept as-is; confirm whether that leniency is intended.
        if len(set(dim0_values) - {1}) == 1:
            return True
        if len(set(dim1_values) - {1}) == 1:
            return True

        # Condition 4: in each dimension at most one entry differs from 1.
        dim0_non_ones = [value for value in dim0_values if value != 1]
        dim1_non_ones = [value for value in dim1_values if value != 1]
        if len(dim0_non_ones) <= 1 and len(dim1_non_ones) <= 1:
            return True

        raise ShapeError(f"the shapes of the inputs {shapes} are not broadcastable")

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Scope of every output feature, stacked over repetitions on axis 1.

        For each repetition, the scopes found at the same feature index of all
        inputs (or of all parts of the split input) are joined with
        ``Scope.join_all``.
        """
        out = []
        for r in range(self.out_shape.repetitions):
            if self.input_is_split:
                # Indexed as (feature within split, split index) for repetition r.
                split_scopes_r = self.inputs[0].feature_to_scope[..., r]
                scope_lists_r = [split_scopes_r[:, i] for i in range(self.num_splits)]
            else:
                scope_lists_r = [module.feature_to_scope[..., r] for module in self.inputs]

            # Group by feature index. zip stops at the shortest list.
            # NOTE(review): confirm this is intended when inputs have
            # differing feature counts (broadcast features).
            feature_to_scope_r = [Scope.join_all(scopes_r) for scopes_r in zip(*scope_lists_r)]
            out.append(np.array(feature_to_scope_r))

        return np.stack(out, axis=1)

    def _tile_over_splits(self, tensor: Tensor, has_channel_axis: bool = False) -> Tensor:
        """Repeat a per-output-feature tensor once per split of the split input.

        Args:
            tensor: Tensor with axes ``(b, f)``, or ``(b, f, c)`` when
                ``has_channel_axis`` is set.
            has_channel_axis: Whether ``tensor`` carries a trailing axis that
                must be kept as the last axis of the result.

        Returns:
            Tensor: Tensor with axes ``(b, num_splits * f, 1)``, followed by
            ``c`` when ``has_channel_axis`` is set. Copies are laid out block
            by block for ``SplitConsecutive`` and alternating for
            ``SplitInterleaved``.

        Raises:
            NotImplementedError: For any other kind of split input.
        """
        tail = " c" if has_channel_axis else ""
        split = self.inputs[0]
        if isinstance(split, SplitConsecutive):
            tiled = repeat(tensor, f"b f{tail} -> b i f{tail}", i=self.num_splits)
            return rearrange(tiled, f"b i f{tail} -> b (i f) 1{tail}")
        if isinstance(split, SplitInterleaved):
            tiled = repeat(tensor, f"b f{tail} -> b f i{tail}", i=self.num_splits)
            return rearrange(tiled, f"b f i{tail} -> b (f i) 1{tail}")
        raise NotImplementedError("Other Split types are not implemented yet.")

    def map_out_channels_to_in_channels(self, output_ids: Tensor) -> Tensor:
        """Map output channel indices to input channel indices.

        Args:
            output_ids: Tensor of output channel indices to map. Integer
                tensors are indexed as ``(b, f)``; floating-point tensors as
                ``(b, f, c)``.

        Returns:
            Tensor: Mapped input channel indices. With a split input the ids
            are repeated once per split. Otherwise a new axis of size
            ``len(self.inputs)`` is added after the feature axis; for inputs
            with a single channel the integer id is set to 0, and the
            floating-point entry at channel 0 is set to 1.

        Raises:
            NotImplementedError: If the split input is neither consecutive
                nor interleaved.
        """
        is_float = output_ids.is_floating_point()

        if self.input_is_split:
            return self._tile_over_splits(output_ids, has_channel_axis=is_float)

        num_inputs = len(self.inputs)

        if not is_float:
            mapped = repeat(output_ids, "b f -> b f i", i=num_inputs)
            single_channel_inputs = [
                i for i, inp in enumerate(self.inputs) if int(inp.out_shape.channels) == 1
            ]
            if single_channel_inputs:
                # einops.repeat can return an expanded (stride-0) view of
                # output_ids on torch; writing into it in place would zero the
                # ids of every input and the caller's tensor. Copy first.
                mapped = mapped.clone()
                for i in single_channel_inputs:
                    mapped[..., i] = 0
            return mapped

        batch_size = int(output_ids.shape[0])
        num_features = int(output_ids.shape[1])
        max_channels = int(self.in_shape.channels)
        mapped = output_ids.new_zeros((batch_size, num_features, num_inputs, max_channels))
        for i, inp in enumerate(self.inputs):
            child_channels = int(inp.out_shape.channels)
            if child_channels == 1:
                # A single-channel input can only ever select its channel 0.
                mapped[:, :, i, 0] = 1.0
            else:
                mapped[:, :, i, :child_channels] = output_ids[:, :, :child_channels]
        return mapped

    def map_out_mask_to_in_mask(self, mask: Tensor) -> Tensor:
        """Map output mask to input mask.

        Args:
            mask: Output mask tensor to map, indexed as ``(b, f)``.

        Returns:
            Tensor: Mapped input mask tensor. With a split input the mask is
            repeated once per split, giving ``(b, num_splits * f, 1)``;
            otherwise it is repeated once per input, giving
            ``(b, f, len(self.inputs))``.

        Raises:
            NotImplementedError: If the split input is neither consecutive
                nor interleaved.
        """
        if self.input_is_split:
            return self._tile_over_splits(mask)
        return repeat(mask, "b f -> b f i", i=len(self.inputs))

    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log likelihood by element-wise summing inputs.

        Args:
            data: Input data tensor.
            cache: Optional cache for storing intermediate computations.

        Returns:
            Tensor: Computed log likelihood values, viewed as
            ``(batch, out features, out channels, out repetitions)``.
        """
        lls = self._get_input_log_likelihoods(data, cache)

        # Expand single-channel inputs so all terms can be stacked. A new list
        # is built so the list returned by the helper is left untouched.
        num_output_channels = self.out_shape.channels
        expanded = []
        for ll in lls:
            if ll.shape[2] == 1:
                if ll.ndim == 4:
                    ll = repeat(ll, "b f 1 r -> b f c r", c=num_output_channels)
                else:
                    ll = repeat(ll, "b f 1 -> b f c", c=num_output_channels)
            expanded.append(ll)

        # A product of densities is a sum of log-likelihoods.
        output = torch.sum(torch.stack(expanded, dim=-1), dim=-1)
        return output.view(
            output.size(0), self.out_shape.features, self.out_shape.channels, self.out_shape.repetitions
        )