# Source code for spflow.zoo.cms.joint

"""Wrapper that exposes joint log-likelihood as a single feature.

Some modules (e.g. multivariate leaves like :class:`~spflow.modules.leaves.CLTree`)
return a log-likelihood tensor with a feature axis, where the joint score is
obtained by summing across features.

This wrapper provides a consistent "root-like" view where
``log_likelihood(data)`` returns shape ``(batch, 1, channels, repetitions)``.

Notes:
    This wrapper performs a *tensor reduction* (sum over the feature axis) and is
    not meant to imply any additional probabilistic independence assumptions.
    In particular, it is not a "Product node" / factorization; it simply changes
    how the score is exposed. Sampling and marginalization are delegated to the
    wrapped module unchanged.
"""

from __future__ import annotations

import numpy as np
from torch import Tensor

from spflow.meta.data.scope import Scope
from spflow.modules.module_shape import ModuleShape
from spflow.modules.wrapper.base import Wrapper
from spflow.utils.cache import Cache, cached
from spflow.utils.sampling_context import LeafParamRecord, SamplingContext


class JointLogLikelihood(Wrapper):
    """Expose a wrapped module's joint log-likelihood as a single feature.

    The wrapped module returns a log-likelihood tensor with a feature axis;
    this wrapper sums over that axis so ``log_likelihood(data)`` has shape
    ``(batch, 1, channels, repetitions)``. Sampling and marginalization are
    delegated to the wrapped module unchanged (see module docstring: this is
    a tensor reduction, not a product/factorization node).
    """

    def __init__(self, module):
        """Wrap ``module`` and collapse its feature axis to a single feature.

        Args:
            module: The module whose joint log-likelihood is exposed.
        """
        super().__init__(module)
        # Feature axis collapses to 1; channels/repetitions pass through.
        self.out_shape = ModuleShape(1, module.out_shape.channels, module.out_shape.repetitions)

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Return a ``(1, repetitions)``-shaped array of joined scopes.

        For each repetition, all of the wrapped module's per-feature scopes
        are joined into a single scope covering the joint score.
        """
        out = []
        for r in range(self.out_shape.repetitions):
            joined = Scope.join_all(self.module.feature_to_scope[:, r])
            out.append(np.array([[joined]]))
        return np.concatenate(out, axis=1)

    @cached
    def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:
        """Return the wrapped module's log-likelihood summed over features.

        Args:
            data: Input data tensor passed through to the wrapped module.
            cache: Optional shared computation cache.

        Returns:
            Tensor of shape ``(batch, 1, channels, repetitions)``.
        """
        ll = self.module.log_likelihood(data, cache=cache)
        # Joint score = sum over the feature axis; keepdim preserves rank.
        return ll.sum(dim=1, keepdim=True)

    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        return_leaf_params: bool = False,
    ) -> Tensor | tuple[Tensor, list[LeafParamRecord]]:
        """Draw samples by delegating to the wrapped module.

        Args:
            num_samples: Number of samples to draw when ``data`` is not given.
            data: Optional evidence tensor; prepared via
                ``_prepare_sample_data``.
            is_mpe: If ``True``, perform MPE inference instead of sampling.
            cache: Optional shared computation cache; created when ``None``.
            return_leaf_params: If ``True``, additionally return the leaf
                parameter records collected during sampling.

        Returns:
            The sample tensor, or ``(samples, leaf_param_records)`` when
            ``return_leaf_params`` is ``True``.
        """
        data = self._prepare_sample_data(num_samples=num_samples, data=data)
        if cache is None:
            cache = Cache()
        # Place the sampling context on the data's device. NOTE(review):
        # ``data.shape[0]`` is dereferenced unconditionally below, so the
        # ``None`` fallback here is effectively dead; it is kept defensively
        # and now uses next()'s default instead of nested try/except.
        if data is not None:
            context_device = data.device
        else:
            context_device = next(
                (t.device for src in (self.module.parameters(), self.module.buffers()) for t in src),
                None,
            )
        sampling_ctx = SamplingContext(
            num_samples=data.shape[0],
            device=context_device,
            is_mpe=is_mpe,
            return_leaf_params=return_leaf_params,
        )
        samples = self.module._sample(data=data, cache=cache, sampling_ctx=sampling_ctx)
        if return_leaf_params:
            return samples, sampling_ctx.leaf_param_records()
        return samples

    def _sample(
        self,
        data: Tensor,
        sampling_ctx: SamplingContext,
        cache: Cache,
    ) -> Tensor:
        # Internal sampling hook: pure delegation to the wrapped module.
        return self.module._sample(data=data, cache=cache, sampling_ctx=sampling_ctx)

    def marginalize(self, marg_rvs: list[int], prune: bool = True, cache: Cache | None = None):
        """Marginalize the wrapped module over ``marg_rvs``.

        Args:
            marg_rvs: Random variables to marginalize out.
            prune: Whether the wrapped module may prune trivial structure.
            cache: Optional shared computation cache.

        Returns:
            ``None`` if everything was marginalized away; the marginalized
            child directly if it already has a single feature; otherwise the
            child re-wrapped in a new :class:`JointLogLikelihood`.
        """
        child = self.module.marginalize(marg_rvs=marg_rvs, prune=prune, cache=cache)
        if child is None:
            return None
        # A single-feature child is already "root-like"; no wrapper needed.
        if child.out_shape.features == 1:
            return child
        return JointLogLikelihood(child)