from __future__ import annotations
from itertools import product
import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor
from torch.nn import functional as F
from spflow.exceptions import ShapeError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.split import Split, SplitMode
from spflow.modules.ops.split_consecutive import SplitConsecutive
from spflow.modules.ops.split_interleaved import SplitInterleaved
from spflow.modules.products.base_product import BaseProduct
from spflow.utils.cache import Cache, cached
class OuterProduct(BaseProduct):
    """Outer product creating all pairwise channel combinations.

    Computes the Cartesian product of input channels: each output channel
    corresponds to one tuple of input channels, so the number of output
    channels equals the product of the input channel counts. All input
    scopes must be pairwise disjoint.

    Attributes:
        unraveled_channel_indices (Tensor): Buffer mapping each output
            channel index to its tuple of per-input channel indices.
    """
[docs]
def __init__(
self,
inputs: list[Module],
num_splits: int | None = 2,
split_mode: SplitMode | None = None,
) -> None:
"""Initialize outer product.
Args:
inputs: Modules with pairwise disjoint scopes.
num_splits: Number of splits for input operations.
split_mode: Optional split configuration for single input mode.
Use SplitMode.consecutive() or SplitMode.interleaved().
Defaults to SplitMode.consecutive(num_splits=num_splits) if not specified.
"""
# Handle single non-Split input: wrap with split_mode
if isinstance(inputs, Module) and not isinstance(inputs, (list, tuple, Split)):
if split_mode is not None:
inputs = split_mode.create(inputs)
else:
inputs = SplitConsecutive(inputs, num_splits=num_splits)
super().__init__(inputs=inputs)
self.check_shapes()
if len(self.inputs) == 1:
if num_splits is None or num_splits <= 1:
raise ValueError("num_splits must be at least 2 when input is a single module")
self.num_splits = num_splits
# Store unraveled channel indices from actual child channel counts.
if self.input_is_split:
split_channels = int(self.inputs[0].out_shape.channels)
child_channel_counts = [split_channels for _ in range(self.num_splits)]
else:
child_channel_counts = [int(inp.out_shape.channels) for inp in self.inputs]
self._child_channel_counts = tuple(child_channel_counts)
unraveled_channel_indices = list(
product(*[list(range(count)) for count in self._child_channel_counts])
)
self.register_buffer(
name="unraveled_channel_indices",
tensor=torch.tensor(unraveled_channel_indices, dtype=torch.long),
)
# Shape computation: compute out_shape based on outer product of channels
input_features = self.inputs[0].out_shape.features
if self.input_is_split:
out_features = int(input_features // self.num_splits)
else:
out_features = input_features
out_channels = len(unraveled_channel_indices)
self.out_shape = ModuleShape(out_features, out_channels, self.in_shape.repetitions)
[docs]
def check_shapes(self):
"""Check if input shapes are compatible for outer product.
Returns:
bool: True if shapes are compatible, False if no shapes to check.
Raises:
ValueError: If input shapes are not broadcastable.
"""
# Compute out_features locally
input_features = self.inputs[0].out_shape.features
if self.input_is_split:
out_features = int(input_features // self.num_splits)
else:
out_features = input_features
# Compute out_channels as product of input channels
ocs = 1
for inp in self.inputs:
ocs *= inp.out_shape.channels
if len(self.inputs) == 1:
ocs = ocs**self.num_splits
out_channels = ocs
if self.input_is_split:
if self.num_splits != self.inputs[0].num_splits:
raise ValueError("num_splits must be the same for all inputs")
shapes = self.inputs[0].get_out_shapes((out_features, out_channels))
else:
shapes = [(inp.out_shape.features, inp.out_shape.channels) for inp in self.inputs]
if not shapes:
return False # No shapes to check
# Extract dimensions
dim0_values = [shape[0] for shape in shapes]
dim1_values = [shape[1] for shape in shapes]
# Check if all shapes have the same first dimension
if len(set(dim0_values)) == 1:
return True
# Check if all shapes have the same second dimension
if len(set(dim1_values)) == 1:
return True
# Check if all but one of the first dimensions are 1
if dim0_values.count(1) == len(dim0_values) - 1:
return True
# Check if all but one of the second dimensions are 1
if dim1_values.count(1) == len(dim1_values) - 1:
return True
# If none of the conditions are satisfied
raise ShapeError(f"the shapes of the inputs {shapes} are not broadcastable")
    @property
    def feature_to_scope(self) -> np.ndarray:
        """Scope of every output feature, per repetition.

        Joins the scopes of the child features combined by each output
        feature of the outer product.

        Returns:
            np.ndarray: Array of ``Scope`` objects of shape
            (out_features, num_repetitions). NOTE(review): the previous
            annotation said ``list[Scope]``, but the code returns ``np.stack``.
        """
        out = []
        for r in range(self.out_shape.repetitions):
            # NOTE(review): sibling methods gate on ``self.input_is_split`` and
            # index ``self.inputs[0]``; here the raw ``self.inputs`` is tested
            # and used directly — confirm both spellings are equivalent.
            if isinstance(self.inputs, Split):
                scope_lists_r = self.inputs.feature_to_scope[
                    ..., r
                ]  # Shape: (num_features_per_split, num_splits)
                scope_lists_r = [scope_lists_r[:, i] for i in range(self.num_splits)]
            else:
                scope_lists_r = [
                    module.feature_to_scope[..., r] for module in self.inputs
                ]  # One scope array per input module.
            # The Cartesian product mirrors the channel-combination order, so
            # joined scopes line up with the outer-product output layout.
            outer_product_r = list(product(*scope_lists_r))
            feature_to_scope_r = []
            for scopes_r in outer_product_r:
                feature_to_scope_r.append(Scope.join_all(scopes_r))
            out.append(np.array(feature_to_scope_r))
        return np.stack(out, axis=1)
[docs]
def map_out_channels_to_in_channels(self, output_ids: Tensor) -> Tensor:
"""Map output channel indices to input channel indices.
Args:
output_ids: Tensor of output channel indices to map.
Returns:
Tensor: Mapped input channel indices corresponding to the output channels.
Raises:
NotImplementedError: If split type is not supported.
"""
if not output_ids.is_floating_point():
if self.input_is_split:
mapped_ids = self.unraveled_channel_indices[output_ids]
if isinstance(self.inputs[0], SplitConsecutive):
return rearrange(mapped_ids, "b f i -> b (i f) 1")
elif isinstance(self.inputs[0], SplitInterleaved):
return rearrange(mapped_ids, "b f i -> b (f i) 1")
else:
raise NotImplementedError("Other Split types are not implemented yet.")
return self.unraveled_channel_indices[output_ids]
batch_size = int(output_ids.shape[0])
num_features = int(output_ids.shape[1])
num_inputs = len(self._child_channel_counts)
max_channels = int(max(self._child_channel_counts))
mapped = output_ids.new_zeros((batch_size, num_features, num_inputs, max_channels))
table = self.unraveled_channel_indices.to(device=output_ids.device)
for i, child_channels in enumerate(self._child_channel_counts):
projection = F.one_hot(table[:, i], num_classes=child_channels).to(dtype=output_ids.dtype)
child_routing = output_ids @ projection
mapped[:, :, i, :child_channels] = child_routing
if self.input_is_split:
if isinstance(self.inputs[0], SplitConsecutive):
return rearrange(mapped, "b f i c -> b (i f) 1 c")
elif isinstance(self.inputs[0], SplitInterleaved):
return rearrange(mapped, "b f i c -> b (f i) 1 c")
else:
raise NotImplementedError("Other Split types are not implemented yet.")
return mapped
[docs]
def map_out_mask_to_in_mask(self, mask: Tensor) -> Tensor:
"""Map output mask to input masks.
Args:
mask: Output mask tensor to map to input masks.
Returns:
Tensor: Mapped input masks corresponding to the output mask.
Raises:
NotImplementedError: If split type is not supported.
"""
num_inputs = len(self.inputs) if not self.input_is_split else self.num_splits
if self.input_is_split:
if isinstance(self.inputs[0], SplitConsecutive):
expanded_mask = repeat(mask, "b f -> b f i", i=num_inputs)
return rearrange(expanded_mask, "b f i -> b (i f) 1")
elif isinstance(self.inputs[0], SplitInterleaved):
expanded_mask = repeat(mask, "b f -> b f i", i=num_inputs)
return rearrange(expanded_mask, "b f i -> b (f i) 1")
else:
raise NotImplementedError("Other Split types are not implemented yet.")
else:
return repeat(mask, "b f -> b f i", i=num_inputs)
[docs]
@cached
def log_likelihood(
self,
data: Tensor,
cache: Cache | None = None,
) -> Tensor:
"""Compute log likelihood via outer sum of pairwise combinations.
Args:
data: Input data tensor for computing log likelihood.
cache: Optional cache for storing intermediate computations.
Returns:
Tensor: Log likelihood values with shape [batch_size, out_features, out_channels, num_repetitions].
Raises:
ValueError: If output tensor has invalid number of dimensions.
"""
# initialize cache
lls = self._get_input_log_likelihoods(data, cache)
# Compute the outer sum of pairwise log-likelihoods
# [b, n, m1] + [b, n, m2] -> [b, n, m1, 1] + [b, n, 1, m2] -> [b, n, m1, m2] -> [b, n, 1, m1*m2] ...
output = lls[0]
for i in range(1, len(lls)):
output = output.unsqueeze(3) + lls[i].unsqueeze(2)
if output.ndim == 4:
output = output.view(output.size(0), self.out_shape.features, -1)
elif output.ndim == 5:
output = output.view(output.size(0), self.out_shape.features, -1, self.out_shape.repetitions)
else:
raise ValueError("Invalid number of dimensions")
# View as [b, n, m1 * m2, r]
if self.out_shape.repetitions is None:
output = output.view(output.size(0), self.out_shape.features, self.out_shape.channels)
else:
output = output.view(
output.size(0), self.out_shape.features, self.out_shape.channels, self.out_shape.repetitions
)
return output