from __future__ import annotations
import numpy as np
import torch
from torch import Tensor
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.cat import Cat
from spflow.utils.cache import Cache, cached
from spflow.utils.sampling_context import SamplingContext
class Product(Module):
    """Product node implementing factorization via conditional independence.

    Computes the joint distribution as the product of its child distributions
    (implemented as a sum over features in log-space). Multiple inputs are
    automatically concatenated along the feature dimension.

    Attributes:
        inputs (Module): Input module(s), concatenated if multiple.
    """

    def __init__(self, inputs: Module | list[Module]) -> None:
        """Initialize product node.

        Args:
            inputs: Single module or list of modules (concatenated along features).
        """
        super().__init__()
        # If inputs is a list, ensure concatenation along the feature dimension.
        # A single-element list needs no Cat wrapper.
        if isinstance(inputs, list):
            if len(inputs) == 1:
                self.inputs = inputs[0]
            else:
                self.inputs = Cat(inputs=inputs, dim=1)
        else:
            self.inputs = inputs

        # Scope of this product module is equal to the scope of its only input.
        self.scope = self.inputs.scope

        # Shape computation: in_shape = inputs.out_shape; the product collapses
        # all input features into a single output feature, keeping channels and
        # repetitions unchanged.
        input_shape = self.inputs.out_shape
        self.in_shape = input_shape
        self.out_shape = ModuleShape(1, input_shape.channels, input_shape.repetitions)

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Map the single output feature to its joined scope per repetition.

        Returns:
            np.ndarray: Array of shape (1, num_repetitions) where entry
            ``[0, r]`` is the union of all input feature scopes of
            repetition ``r``.
        """
        out = []
        for r in range(self.out_shape.repetitions):
            # Join the scopes of all input features for repetition r.
            joined_scope = Scope.join_all(self.inputs.feature_to_scope[:, r])
            out.append(np.array([[joined_scope]]))
        return np.concatenate(out, axis=1)

    @cached
    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log likelihood by summing child log-likelihoods across features.

        Args:
            data: Input data tensor.
            cache: Optional cache for storing intermediate results.

        Returns:
            Tensor: Log likelihood values with a singleton feature dimension.
        """
        # Compute child log-likelihoods.
        ll = self.inputs.log_likelihood(
            data,
            cache=cache,
        )
        # Multiply children (sum in log-space); keepdim preserves the
        # feature axis so out_shape.features == 1 holds.
        result = torch.sum(ll, dim=1, keepdim=True)
        return result

    def _sample(
        self,
        data: Tensor,
        sampling_ctx: SamplingContext,
        cache: Cache,
    ) -> Tensor:
        """Generate samples by delegating to the input module.

        A product node has no sampling decision of its own: it broadcasts the
        sampling context to all input features and lets the input sample.

        Args:
            data: Data tensor to fill with samples (modified in place).
            sampling_ctx: Sampling context (validated and broadcast here).
            cache: Cache for storing intermediate results.

        Returns:
            Tensor: The (in-place filled) ``data`` tensor.
        """
        sampling_ctx.validate_sampling_context(
            num_samples=data.shape[0],
            num_features=self.out_shape.features,
            num_channels=self.out_shape.channels,
            num_repetitions=self.out_shape.repetitions,
            allowed_feature_widths=(1, self.out_shape.features),
        )
        # Expand the single-feature context to cover all input features.
        sampling_ctx.broadcast_feature_width(
            target_features=self.inputs.out_shape.features,
            allow_from_one=True,
        )
        # Delegate to input module for actual sampling.
        self.inputs._sample(
            data=data,
            cache=cache,
            sampling_ctx=sampling_ctx,
        )
        return data

    def _expectation_maximization_step(
        self,
        data: Tensor,
        bias_correction: bool = True,
        *,
        cache: Cache,
    ) -> None:
        """EM step (delegates to input, no learnable parameters).

        Args:
            data: Input data tensor for EM step.
            bias_correction: Whether to apply bias correction. Defaults to True.
            cache: Cache for storing intermediate results.
        """
        # Product has no learnable parameters; delegate to input.
        self.inputs._expectation_maximization_step(data, cache=cache, bias_correction=bias_correction)

    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> Product | Module | None:
        """Marginalize out specified random variables.

        Args:
            marg_rvs: List of random variable indices to marginalize.
            prune: Whether to prune unnecessary nodes.
            cache: Optional cache for storing intermediate results.

        Returns:
            Product | Module | None: Marginalized module, the pruned child,
            or None if the full scope is marginalized.
        """
        # Compute layer scope (same for all outputs).
        layer_scope = self.scope
        marg_child = None
        mutual_rvs = set(layer_scope.query).intersection(set(marg_rvs))

        if len(mutual_rvs) == len(layer_scope.query):
            # The whole scope of this branch is marginalized over.
            return None
        elif mutual_rvs:
            # Node scope is partially marginalized: marginalize child modules.
            marg_child_layer = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
            # Explicit None check: a valid module must not be dropped just
            # because it happens to be falsy (e.g. defines __len__ == 0).
            if marg_child_layer is not None:
                marg_child = marg_child_layer
        else:
            # Scope untouched: keep the input as-is.
            marg_child = self.inputs

        if marg_child is None:
            return None
        elif prune and marg_child.out_shape.features == 1:
            # A product over a single feature is the identity; prune it.
            return marg_child
        else:
            return Product(inputs=marg_child)