# Source code for spflow.modules.sos.socs

from __future__ import annotations

from collections.abc import Iterable
from typing import cast

import numpy as np
import torch
from einops import reduce
from torch import Tensor

from spflow.exceptions import ShapeError, UnsupportedOperationError
from spflow.meta.data.scope import Scope
from spflow.modules.module import Module
from spflow.modules.ops.cat import Cat
from spflow.modules.products.product import Product
from spflow.modules.sums.signed_sum import SignedSum
from spflow.modules.sums.sum import Sum
from spflow.utils.cache import Cache, cached
from spflow.utils.inner_product import inner_product_matrix, log_self_inner_product_scalar
from spflow.utils.sampling_context import SamplingContext


def _is_signed_categorical(module: Module) -> bool:
    """Tell whether `module` is a ``SignedCategorical`` that supports signed evaluation.

    The class is recognised by its name (not via ``isinstance``), and the
    instance must additionally expose a ``signed_logabs_and_sign`` attribute.
    """
    class_name = module.__class__.__name__
    if class_name != "SignedCategorical":
        return False
    return hasattr(module, "signed_logabs_and_sign")


def _contains_signed_sum(module: Module) -> bool:
    """Report whether any module yielded by ``module.modules()`` is signed.

    A module counts as signed when it is a `SignedSum` instance or when
    `_is_signed_categorical` accepts it.
    """
    return any(
        isinstance(sub, SignedSum) or _is_signed_categorical(cast(Module, sub))
        for sub in module.modules()
    )


def _signed_eval(module: Module, data: Tensor, cache: Cache) -> tuple[Tensor, Tensor]:
    """Evaluate `module` as a real-valued function in (log|·|, sign) form.

    Results are memoised in `cache` under the ``"signed_eval"`` key, so a module
    reached several times during one traversal is evaluated once.

    Args:
        module: Module to evaluate.
        data: Input data forwarded to the module's evaluation routines.
        cache: Traversal cache used for lookup and storage of the result.

    Returns:
        logabs, sign of shape (B, F, C, R).
    """
    hit = cache.get("signed_eval", module)
    if hit is not None:
        return hit

    def _nonnegative() -> tuple[Tensor, Tensor]:
        # A non-negative module: its log-likelihood is the log-abs and every sign is +1.
        log_values = module.log_likelihood(data, cache=cache)
        return log_values, torch.ones_like(log_values, dtype=torch.int8)

    if hasattr(module, "signed_logabs_and_sign"):
        # The module knows how to evaluate itself in signed form.
        result = module.signed_logabs_and_sign(data, cache=cache)  # type: ignore[attr-defined]
    elif isinstance(module, Sum):
        # Leaves and monotone internal modules: use log_likelihood as log-abs and sign=+1.
        result = _nonnegative()
    elif isinstance(module, Cat):
        # Evaluate every input and concatenate both halves along the Cat dimension.
        evaluated = [_signed_eval(cast(Module, child), data, cache) for child in module.inputs]
        result = (
            torch.cat([logabs_part for logabs_part, _ in evaluated], dim=module.dim),
            torch.cat([sign_part for _, sign_part in evaluated], dim=module.dim),
        )
    elif isinstance(module, Product):
        in_logabs, in_sign = _signed_eval(cast(Module, module.inputs), data, cache)
        # Multiply over features => add log-abs, multiply signs.
        # Signs are widened to int16 for the product and narrowed back to int8.
        result = (
            in_logabs.sum(dim=1, keepdim=True),
            in_sign.to(dtype=torch.int16).prod(dim=1, keepdim=True).to(dtype=torch.int8),
        )
    else:
        # Default: non-negative module
        result = _nonnegative()

    cache.set("signed_eval", module, result)
    return result


class SOCS(Module):
    """Sum of Compatible Squares (SOCS) wrapper module.

    Represents a non-negative density of the form:

        c(x) = Σ_i c_i(x)^2
        p(x) = c(x) / Z,  where  Z = ∫ c(x) dx = Σ_i ∫ c_i(x)^2 dx

    Notes:
        - `log_likelihood()` is supported for signed components built with `SignedSum`.
        - `sample()` supports unconditional sampling of scalar-output circuits
          using a Metropolis–Hastings independence sampler. A component without
          signed nodes serves as its own proposal; a component containing signed
          nodes uses the proposal returned by `build_abs_weight_proposal`.
        - MPE, conditional sampling, differentiable routing and
          expectation-maximization are not supported.
    """

    def __init__(self, components: list[Module]) -> None:
        """Wrap `components` into a SOCS module.

        Args:
            components: Non-empty list of modules sharing one scope and one `out_shape`.

        Raises:
            ValueError: If `components` is empty.
            ShapeError: If the components differ in scope or in `out_shape`.
        """
        super().__init__()
        if len(components) < 1:
            raise ValueError("SOCS requires at least one component.")

        # Validate scope equality and compatible output shapes.
        scope = components[0].scope
        out_shape0 = components[0].out_shape
        for c in components:
            if not Scope.all_equal([scope, c.scope]):
                raise ShapeError("All SOCS components must have identical scope.")
            if tuple(c.out_shape) != tuple(out_shape0):
                raise ShapeError(
                    "All SOCS components must have identical out_shape; "
                    f"got {tuple(out_shape0)} vs {tuple(c.out_shape)}."
                )

        self.components = torch.nn.ModuleList(components)
        self.scope = scope
        self.in_shape = components[0].in_shape
        self.out_shape = components[0].out_shape

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Feature-to-scope mapping, taken from the first component."""
        return cast(Module, self.components[0]).feature_to_scope

    def _log_partition(self, cache: Cache) -> Tensor:
        """Compute log Z per output entry (shape: (F, C, R)).

        The result is memoised in `cache` under ``"socs_logZ"`` and also mirrored
        into ``cache.extras["socs_logZ"]``.
        """
        hit = cache.get("socs_logZ", self)
        if hit is not None:
            # Keep a convenient handle for downstream inspection/debugging.
            cache.extras["socs_logZ"] = hit
            return cast(Tensor, hit)

        # Z[f,c,r] = Σ_i ∫ c_{i,f,c,r}(x)^2 dx;
        # each component contributes the diagonal of its self inner-product.
        z_parts = []
        for comp in self.components:
            k = inner_product_matrix(cast(Module, comp), cast(Module, comp), cache=cache)  # (F, C, C, R)
            diag = torch.diagonal(k, dim1=1, dim2=2)  # (F, R, C)
            z_parts.append(diag.permute(0, 2, 1))  # (F, C, R)

        Z = reduce(torch.stack(z_parts, dim=0), "n f c r -> f c r", "sum")
        # Clamp before the log so a vanishing Z does not produce -inf.
        logZ = torch.log(torch.clamp(Z, min=1e-30))
        cache.set("socs_logZ", self, logZ)
        cache.extras["socs_logZ"] = logZ
        return logZ

    @cached
    def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:  # type: ignore[override]
        """Return log p(x) = log Σ_i c_i(x)^2 - log Z, elementwise over output entries.

        Args:
            data: Input data passed to every component.
            cache: Traversal cache shared with the component evaluations.

        Returns:
            Normalised log-density (per the inline shape notes: (B, F, C, R)).
        """
        # log c(x) = log Σ_i exp(2 log|c_i(x)|)  (elementwise over output entries)
        comp_terms = []
        for comp in self.components:
            # The sign is irrelevant here because each component is squared.
            logabs, _sign = _signed_eval(cast(Module, comp), data, cache)
            comp_terms.append(2.0 * logabs)
        stacked = torch.stack(comp_terms, dim=0)  # (r, B, F, C, R)
        log_c = torch.logsumexp(stacked, dim=0)  # (B, F, C, R)

        logZ = self._log_partition(cache).to(dtype=log_c.dtype, device=log_c.device).unsqueeze(0)
        return log_c - logZ

    def _expectation_maximization_step(
        self,
        data: Tensor,
        bias_correction: bool = True,
        *,
        cache: Cache,
    ) -> None:
        """Always raise: expectation-maximization is not available for SOCS.

        Raises:
            UnsupportedOperationError: Unconditionally.
        """
        raise UnsupportedOperationError("SOCS does not support expectation-maximization.")

    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> Module | None:
        """Marginalize every component and rebuild the SOCS wrapper.

        Args:
            marg_rvs: Random-variable indices to marginalize out.
            prune: Forwarded to each component's `marginalize`.
            cache: Forwarded to each component's `marginalize`.

        Returns:
            A new `SOCS` over the marginalized components, or ``None`` as soon as
            any component marginalizes to ``None``.
        """
        new_components: list[Module] = []
        for comp in self.components:
            m = cast(Module, comp).marginalize(marg_rvs, prune=prune, cache=cache)
            if m is None:
                return None
            new_components.append(m)

        # Keep the SOCS wrapper even for a single component (and regardless of
        # `prune`); semantics differ from the raw component.
        return SOCS(new_components)

    def _require_unconditional_scalar_sampling(self, data: Tensor, is_mpe: bool) -> None:
        """Reject sampling requests that SOCS cannot serve.

        Raises:
            UnsupportedOperationError: If MPE is requested, if `data` contains any
                finite entry (i.e. evidence), or if `out_shape` is not ``(1, 1, 1)``.
        """
        if is_mpe:
            raise UnsupportedOperationError("SOCS.mpe() is not supported (use MAP on components if needed).")

        # Only unconditional sampling for now (all NaNs)
        if torch.isfinite(data).any():
            raise UnsupportedOperationError(
                "SOCS.sample() does not support conditional sampling with evidence yet."
            )

        if tuple(self.out_shape) != (1, 1, 1):
            raise UnsupportedOperationError(
                "SOCS.sample() currently supports only scalar-output circuits "
                "(out_shape.features==1, out_shape.channels==1, out_shape.repetitions==1)."
            )

    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
    ) -> Tensor:
        """Draw unconditional samples from the SOCS density.

        Args:
            num_samples: Number of samples; passed to `_prepare_sample_data`.
            data: Optional sample buffer; passed to `_prepare_sample_data`.
            is_mpe: Must be falsy; MPE is not supported.
            cache: Traversal cache forwarded to the base-class `sample`.

        Raises:
            UnsupportedOperationError: For MPE, for evidence in `data`, or for a
                non-scalar `out_shape`.
        """
        data = self._prepare_sample_data(num_samples=num_samples, data=data)
        self._require_unconditional_scalar_sampling(data, is_mpe=is_mpe)

        return super().sample(
            num_samples=None,
            data=data,
            is_mpe=is_mpe,
            cache=cache,
        )

    def _sample(
        self,
        data: Tensor,
        sampling_ctx: SamplingContext,
        cache: Cache,
    ) -> Tensor:
        """Sample via a component mixture followed by independence Metropolis–Hastings.

        A component index is drawn per sample with logits given by each
        component's log self inner product; each selected component is then
        sampled with an MH independence kernel targeting c_i(x)^2.

        The chain length is read from ``cache.extras``: ``"socs_mcmc_steps"``
        (fallback ``"socs_mh_steps"``, default 50) and ``"socs_mcmc_burn_in"``
        (fallback ``"socs_mh_burn_in"``, default 10).

        Raises:
            UnsupportedOperationError: For differentiable routing, MPE, evidence
                in `data`, or a non-scalar `out_shape`.
            ValueError: If the configured step count is < 1 or burn-in is < 0.
        """
        if sampling_ctx.is_differentiable:
            raise UnsupportedOperationError(
                "SOCS.sample() does not support differentiable routing yet."
            )
        self._require_unconditional_scalar_sampling(data, is_mpe=sampling_ctx.is_mpe)

        num_samples = data.shape[0]

        # Mixture over components with weights proportional to Z_i
        logZs = torch.stack(
            [log_self_inner_product_scalar(cast(Module, c), cache=cache) for c in self.components]
        )
        comp_idx = torch.distributions.Categorical(logits=logZs).sample((num_samples,))

        # MCMC settings (can be overridden via cache.extras when cache is provided).
        cache_extras = cache.extras
        steps_after_burn_in = int(cache_extras.get("socs_mcmc_steps", cache_extras.get("socs_mh_steps", 50)))
        burn_in = int(cache_extras.get("socs_mcmc_burn_in", cache_extras.get("socs_mh_burn_in", 10)))
        if steps_after_burn_in < 1:
            raise ValueError("socs_mcmc_steps must be >= 1.")
        if burn_in < 0:
            raise ValueError("socs_mcmc_burn_in must be >= 0.")
        total_steps = burn_in + steps_after_burn_in

        # The log-densities below only feed the accept/reject comparison, so they
        # are evaluated without building an autograd graph.
        def _joint_ll(mod: Module, x: Tensor) -> Tensor:
            # Do not reuse the traversal cache across different MCMC states.
            with torch.no_grad():
                ll = mod.log_likelihood(x, cache=None)
            return reduce(ll, "b f 1 1 -> b", "sum")

        def _log_target_signed(mod: Module, x: Tensor) -> Tensor:
            # Same: the Cache implementation is per-module (not per-data), so it must
            # not be re-used when evaluating different x values in MCMC.
            eval_cache = Cache()
            with torch.no_grad():
                logabs, _sign = _signed_eval(mod, x, eval_cache)
            return reduce(2.0 * logabs, "b f 1 1 -> b", "sum")

        # Sample per-component with an independence MH kernel:
        # target π(x) ∝ c_i(x)^2, proposal q(x) from a monotone PC (either c_i itself or abs-weight proxy).
        out = data.clone()
        for i, comp in enumerate(self.components):
            mask = comp_idx == i
            if not mask.any():
                continue
            n_i = int(mask.sum().item())

            comp_mod = cast(Module, comp)
            has_signed = _contains_signed_sum(comp_mod)
            if has_signed:
                # Local import to avoid a circular import: learn.build_socs -> SOCS.
                from spflow.learn.build_socs import build_abs_weight_proposal

                proposal = build_abs_weight_proposal(comp_mod)
            else:
                proposal = comp_mod

            # Initial chain state and its proposal / target log-densities.
            x = proposal.sample(num_samples=n_i)
            log_q_x = _joint_ll(proposal, x)
            log_t_x = _log_target_signed(comp_mod, x)

            for _t in range(total_steps):
                x_p = proposal.sample(num_samples=n_i)
                log_q_p = _joint_ll(proposal, x_p)
                log_t_p = _log_target_signed(comp_mod, x_p)

                # Independence-sampler ratio: π(x') q(x) / (π(x) q(x')), in log space.
                log_alpha = (log_t_p - log_t_x) + (log_q_x - log_q_p)
                u = torch.log(torch.rand_like(log_alpha))
                accept = u < torch.minimum(log_alpha, torch.zeros_like(log_alpha))

                x = torch.where(accept.unsqueeze(1), x_p, x)
                log_q_x = torch.where(accept, log_q_p, log_q_x)
                log_t_x = torch.where(accept, log_t_p, log_t_x)

            out[mask] = x

        return out