from __future__ import annotations
from typing import Optional
import numpy as np
import torch
from torch import Tensor, nn
from spflow.exceptions import InvalidParameterError
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.utils.cache import Cache, cached
from spflow.utils.sampling_context import (
SamplingContext,
)
class Cat(Module):
    """Concatenate the outputs of several modules along one axis.

    Two modes are supported, selected by ``dim``:

    * ``dim=1`` (features): all inputs must have the same number of channels
      and pairwise disjoint scopes. The resulting scope is the join of the
      input scopes and the output feature count is the sum of the inputs'.
    * ``dim=2`` (channels): all inputs must have the same number of features
      and identical scopes. The output channel count is the sum of the
      inputs'.

    The number of repetitions of the output shape is taken from the first
    input; it is not validated against the other inputs here.
    """

    def __init__(self, inputs: list[Module], dim: int = -1):
        """Initialize concatenation operation.

        Args:
            inputs: Modules to concatenate. Must contain at least one module.
            dim: Concatenation dimension (1=feature, 2=channel). Any other
                value, including the default ``-1``, is rejected, so callers
                have to pass ``dim`` explicitly.

        Raises:
            ValueError: If ``inputs`` is empty, if ``dim`` is neither 1 nor 2,
                or if the inputs' shapes/scopes are incompatible with the
                requested concatenation dimension.
        """
        super().__init__()
        self.inputs = nn.ModuleList(inputs)
        self.dim = dim

        # Everything below indexes ``self.inputs[0]``; fail with a clear
        # message instead of an IndexError when nothing was passed in.
        if len(self.inputs) == 0:
            raise ValueError("Cat requires at least one input module.")

        first_shape = self.inputs[0].out_shape

        if self.dim == 1:
            # Feature concatenation: channels must line up across inputs.
            if not all(
                module.out_shape.channels == first_shape.channels for module in self.inputs
            ):
                raise ValueError("All inputs must have the same number of channels.")
            # Each input contributes its own variables, so scopes may not overlap.
            if not Scope.all_pairwise_disjoint([module.scope for module in self.inputs]):
                raise ValueError("All inputs must have disjoint scopes.")
            # Scope is the join of all input scopes.
            self._scope = Scope.join_all([inp.scope for inp in self.inputs])
            out_features = sum(module.out_shape.features for module in self.inputs)
            out_channels = first_shape.channels
        elif self.dim == 2:
            # Channel concatenation: features and scopes must be identical.
            if not all(
                module.out_shape.features == first_shape.features for module in self.inputs
            ):
                raise ValueError("All inputs must have the same number of features.")
            if not Scope.all_equal([module.scope for module in self.inputs]):
                raise ValueError("All inputs must have the same scope.")
            # Scope is the same as that of every input.
            self._scope = self.inputs[0].scope
            out_features = first_shape.features
            out_channels = sum(module.out_shape.channels for module in self.inputs)
        else:
            raise ValueError("Invalid dimension for concatenation.")

        # Shape bookkeeping: the input shape is that of the first child;
        # repetitions are copied from the first child as well.
        self.in_shape = first_shape
        self.out_shape = ModuleShape(out_features, out_channels, first_shape.repetitions)

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Return the feature-to-scope mapping of the concatenated output.

        For ``dim == 1`` the children's mappings are concatenated along
        axis 0, in input order. Otherwise all inputs share one scope and the
        first input's mapping is returned unchanged.
        """
        if self.dim == 1:
            return np.concatenate([module.feature_to_scope for module in self.inputs], axis=0)
        return self.inputs[0].feature_to_scope

    def extra_repr(self) -> str:
        """Return the parent's extra representation followed by ``dim=<dim>``."""
        base = super().extra_repr()
        dim_repr = f"dim={self.dim}"
        # Avoid a dangling leading ", " when the parent has nothing to report.
        return f"{base}, {dim_repr}" if base else dim_repr

    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log likelihood by concatenating input log-likelihoods.

        Args:
            data: Input data tensor, passed unchanged to every input module.
            cache: Optional cache for storing intermediate results; forwarded
                to every input module.

        Returns:
            Tensor: The inputs' log-likelihood tensors concatenated along
            ``self.dim``, in input order.
        """
        lls = [input_module.log_likelihood(data, cache=cache) for input_module in self.inputs]
        return torch.cat(lls, dim=self.dim)

    def _sample(
        self,
        data: Tensor,
        sampling_ctx: SamplingContext,
        cache: Cache,
    ) -> Tensor:
        """Generate samples by delegating to the concatenated inputs.

        The incoming sampling context is validated against this module's
        output shape, broadcast to the full output feature width, and then
        split into one routing (channel index + mask) per input module.

        Args:
            data: Data tensor handed to every input module; its first
                dimension is used as the number of samples.
            sampling_ctx: Sampling context controlling which outputs of this
                module are sampled.
            cache: Cache forwarded to the input modules.

        Returns:
            Tensor: The ``data`` tensor that was passed in. The children's
            return values are ignored, so they are presumably expected to
            write their samples into ``data`` in place.

        Raises:
            InvalidParameterError: For ``dim == 2`` with a differentiable
                sampling context, if some active (batch, feature) entry is
                not routed to exactly one child.
            ValueError: If ``self.dim`` is neither 1 nor 2.
        """
        # Check the sampling context against this module's output shape and
        # expand a width-1 feature routing to the full feature width.
        sampling_ctx.validate_sampling_context(
            num_samples=data.shape[0],
            num_features=self.out_shape.features,
            num_channels=self.out_shape.channels,
            num_repetitions=self.out_shape.repetitions,
            allowed_feature_widths=(1, self.out_shape.features),
        )
        sampling_ctx.broadcast_feature_width(target_features=self.out_shape.features, allow_from_one=True)

        if self.dim == 1:
            # Each input owns a contiguous [start, stop) slice of the features.
            ranges: list[tuple[int, int]] = []
            feature_offset = 0
            for module in self.inputs:
                num_features = module.out_shape.features
                ranges.append((feature_offset, feature_offset + num_features))
                feature_offset += num_features
            per_module = sampling_ctx.slice_feature_ranges(ranges=ranges)
        elif self.dim == 2:
            # Channel indices are routed to the child that owns that channel.
            per_module = sampling_ctx.route_channel_offsets(
                child_channel_counts=[int(module.out_shape.channels) for module in self.inputs],
            )
            if sampling_ctx.is_differentiable:
                # Count how many children claim each entry; every active
                # entry must be claimed by exactly one child.
                ownership = torch.stack([child_mask for _, child_mask in per_module], dim=0).sum(dim=0)
                invalid = (ownership != 1) & sampling_ctx.mask
                if invalid.any():
                    raise InvalidParameterError(
                        "Cat(dim=2) differentiable routing must select exactly one child per active (batch, feature)."
                    )
        else:
            raise ValueError("Invalid dimension for concatenation.")

        # Each routing entry is a (channel_index, mask) pair for one input.
        channel_index_per_module = [pair[0] for pair in per_module]
        mask_per_module = [pair[1] for pair in per_module]

        # Delegate to every input with its own routing.
        for i, input_module in enumerate(self.inputs):
            sampling_ctx_copy = sampling_ctx.with_routing(
                channel_index=channel_index_per_module[i],
                mask=mask_per_module[i],
            )
            input_module._sample(
                data=data,
                cache=cache,
                sampling_ctx=sampling_ctx_copy,
            )
        return data

    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> Optional["Module"]:
        """Marginalize out specified random variables.

        Args:
            marg_rvs: List of random variable indices to marginalize.
            prune: Whether to prune unnecessary modules after marginalization.
                When set and only one input survives, that input is returned
                directly instead of being wrapped in a new ``Cat``.
            cache: Optional cache for storing intermediate results; forwarded
                to the input modules.

        Returns:
            Optional[Module]: ``self`` if none of ``marg_rvs`` is in this
            module's scope, ``None`` if every input was fully marginalized,
            otherwise the single surviving input (when pruning) or a new
            ``Cat`` over the marginalized inputs.
        """
        # Nothing to do when the scope does not contain any of the variables.
        if not set(self.scope.query).intersection(marg_rvs):
            return self

        # Marginalize every child and keep those that did not vanish.
        surviving = []
        for input_module in self.inputs:
            marg_child_module = input_module.marginalize(marg_rvs, prune=prune, cache=cache)
            # Compare against None explicitly: a module that defines
            # ``__len__`` may be falsy even though it was not marginalized away.
            if marg_child_module is not None:
                surviving.append(marg_child_module)

        # All children were marginalized out.
        if not surviving:
            return None
        # A concatenation of a single module is redundant when pruning.
        if prune and len(surviving) == 1:
            return surviving[0]
        return Cat(inputs=surviving, dim=self.dim)