from __future__ import annotations
import numpy as np
import torch
from torch import Tensor
from spflow.meta.data import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.cat import Cat
from spflow.utils.cache import Cache, cached
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context
class Product(Module):
"""Product node implementing factorization via conditional independence.
Computes joint distribution as product of child distributions. Multiple
inputs are automatically concatenated along the feature dimension.
Attributes:
inputs (Module): Input module(s), concatenated if multiple.
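
    Example:
        A minimal usage sketch. The ``Normal`` leaf and its import path are
        assumptions for illustration, not part of this module:

        >>> from spflow.modules.leaf import Normal  # assumed import path
        >>> leaf = Normal(scope=Scope([0, 1]))  # assumed leaf constructor
        >>> product = Product(inputs=leaf)
        >>> product.out_shape.features  # all features collapse into one
        1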
"""
def __init__(self, inputs: Module | list[Module]) -> None:
"""Initialize product node.
Args:
inputs: Single module or list of modules (concatenated along features).
"""
super().__init__()
# If inputs is a list, ensure concatenation along the feature dimension
if isinstance(inputs, list):
if len(inputs) == 1:
self.inputs = inputs[0]
else:
self.inputs = Cat(inputs=inputs, dim=1)
else:
self.inputs = inputs
# Scope of this product module is equal to the scope of its only input
self.scope = self.inputs.scope
# Shape computation: in_shape = inputs.out_shape, out_shape has 1 feature
input_shape = self.inputs.out_shape
self.in_shape = input_shape
self.out_shape = ModuleShape(1, input_shape.channels, input_shape.repetitions)
@property
    def feature_to_scope(self) -> np.ndarray:
        """Map output features to their scopes.

        Joins the scopes of all input features for each repetition.

        Returns:
            np.ndarray: Array of shape ``(1, repetitions)`` holding the joined
            ``Scope`` of each repetition.
        """
out = []
for r in range(self.out_shape.repetitions):
joined_scope = Scope.join_all(self.inputs.feature_to_scope[:, r])
out.append(np.array([[joined_scope]]))
return np.concatenate(out, axis=1)
@cached
def log_likelihood(
self,
data: Tensor,
cache: Cache | None = None,
) -> Tensor:
"""Compute log likelihood by summing child log-likelihoods across features.
Args:
data: Input data tensor.
cache: Optional cache for storing intermediate results.
Returns:
Tensor: Log likelihood values.
"""
# compute child log-likelihoods
ll = self.inputs.log_likelihood(
data,
cache=cache,
)
# multiply children (sum in log-space)
result = torch.sum(ll, dim=1, keepdim=True)
return result
def sample(
self,
num_samples: int | None = None,
data: Tensor | None = None,
is_mpe: bool = False,
cache: Cache | None = None,
sampling_ctx: SamplingContext | None = None,
) -> Tensor:
"""Generate samples by delegating to input module.
Args:
num_samples: Number of samples to generate.
data: Optional data tensor to fill with samples.
is_mpe: Whether to perform most probable explanation.
cache: Optional cache for storing intermediate results.
sampling_ctx: Optional sampling context.
Returns:
Tensor: Generated samples.
"""
# Handle num_samples case (create empty data tensor)
if data is None:
if num_samples is None:
num_samples = 1
data = torch.full((num_samples, len(self.scope.query)), torch.nan, device=self.device)
# Initialize sampling context if not provided
sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0], data.device)
# Expand mask and channels to match input module shape
mask = sampling_ctx.mask.expand(data.shape[0], self.inputs.out_shape.features)
channel_index = sampling_ctx.channel_index.expand(data.shape[0], self.inputs.out_shape.features)
sampling_ctx.update(channel_index=channel_index, mask=mask)
# Delegate to input module for actual sampling
self.inputs.sample(
data=data,
is_mpe=is_mpe,
cache=cache,
sampling_ctx=sampling_ctx,
)
return data
def expectation_maximization(
self,
data: Tensor,
bias_correction: bool = True,
cache: Cache | None = None,
) -> None:
"""EM step (delegates to input, no learnable parameters).
Args:
data: Input data tensor for EM step.
bias_correction: Whether to apply bias correction. Defaults to True.
cache: Optional cache for storing intermediate results.
"""
# Product has no learnable parameters, delegate to input
self.inputs.expectation_maximization(data, cache=cache, bias_correction=bias_correction)
def maximum_likelihood_estimation(
self,
data: Tensor,
weights: Tensor | None = None,
cache: Cache | None = None,
) -> None:
"""MLE step (delegates to input, no learnable parameters).
Args:
data: Input data tensor for MLE step.
weights: Optional weights for weighted MLE.
cache: Optional cache for storing intermediate results.
"""
# Product has no learnable parameters, delegate to input
self.inputs.maximum_likelihood_estimation(
data,
weights=weights,
cache=cache,
)
def marginalize(
self,
marg_rvs: list[int],
prune: bool = True,
cache: Cache | None = None,
) -> Product | Module | None:
"""Marginalize out specified random variables.
Args:
marg_rvs: List of random variable indices to marginalize.
prune: Whether to prune unnecessary nodes.
cache: Optional cache for storing intermediate results.
Returns:
Product | Module | None: Marginalized module or None if fully marginalized.
"""
# compute layer scope (same for all outputs)
layer_scope = self.scope
marg_child = None
mutual_rvs = set(layer_scope.query).intersection(set(marg_rvs))
        # the full layer scope is being marginalized over
        if len(mutual_rvs) == len(layer_scope.query):
            # the entire scope of this branch is marginalized out, so drop it
            return None
# node scope is being partially marginalized
elif mutual_rvs:
            # marginalize the child module
            marg_child_layer = self.inputs.marginalize(marg_rvs, prune=prune, cache=cache)
            # keep the marginalized child if anything remains
            if marg_child_layer is not None:
                marg_child = marg_child_layer
        else:
            # no overlap with the marginalization set: keep the input unchanged
            marg_child = self.inputs
if marg_child is None:
return None
elif prune and marg_child.out_shape.features == 1:
return marg_child
else:
return Product(inputs=marg_child)