Source code for spflow.modules.sums.signed_sum

from __future__ import annotations

from typing import cast

import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor, nn

from spflow.exceptions import ShapeError, UnsupportedOperationError
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.modules.ops.cat import Cat
from spflow.utils.cache import Cache, cached
from spflow.utils.sampling_context import (
    SamplingContext,
    update_channel_index_strict,
)
from spflow.utils.signed_semiring import signed_logsumexp, sign_of


class SignedSum(Module):
    """Linear-combination node that allows negative, non-normalized weights.

    This node is *not* a probabilistic mixture node. It represents a real-valued
    linear combination of input channels:

        y = Σ_j w_j * x_j

    where weights may be negative and do not need to sum to one.

    Notes:
        - `SignedSum` does not implement `log_likelihood()` because its output may
          be negative (log is undefined). Use SOCS or signed evaluation utilities
          for inference.
        - `sample()` is only supported when all weights are non-negative and no
          evidence is present, in which case it behaves like an unnormalized
          mixture over inputs.
    """

    def __init__(
        self,
        inputs: Module | list[Module],
        out_channels: int = 1,
        num_repetitions: int = 1,
        weights: Tensor | None = None,
    ) -> None:
        """Build the node and its weight parameter.

        Args:
            inputs: A single input module or a list of input modules. A
                one-element list is unwrapped; a longer list is wrapped in
                ``Cat(inputs=inputs, dim=2)``.
            out_channels: Number of output channels.
            num_repetitions: Number of repetitions of the output shape.
            weights: Optional initial weights of shape
                ``(features, in_channels, out_channels, repetitions)``. They are
                converted to the default torch dtype. When omitted, weights are
                drawn with ``torch.randn``.

        Raises:
            ValueError: If ``inputs`` is falsy (e.g. ``None`` or an empty list).
            ShapeError: If ``weights`` does not have the expected shape.
        """
        super().__init__()

        if not inputs:
            raise ValueError("'SignedSum' requires at least one input.")

        if isinstance(inputs, list):
            if len(inputs) == 1:
                self.inputs = inputs[0]
            else:
                self.inputs = Cat(inputs=inputs, dim=2)
        else:
            self.inputs = inputs

        self.sum_dim = 1  # sum over input channels
        self.scope = self.inputs.scope

        self.in_shape = self.inputs.out_shape
        self.out_shape = ModuleShape(
            features=self.in_shape.features,
            channels=out_channels,
            repetitions=num_repetitions,
        )

        # Layout: (features, input channels, output channels, repetitions).
        self.weights_shape = (
            self.in_shape.features,
            self.in_shape.channels,
            self.out_shape.channels,
            self.out_shape.repetitions,
        )

        if weights is None:
            weights = torch.randn(self.weights_shape, dtype=torch.get_default_dtype())
        else:
            weights = torch.as_tensor(weights, dtype=torch.get_default_dtype())
            if tuple(weights.shape) != tuple(self.weights_shape):
                raise ShapeError(
                    f"Invalid shape for weights: was {tuple(weights.shape)} but expected {self.weights_shape}."
                )

        self.weights = nn.Parameter(weights)

    @property
    def feature_to_scope(self) -> np.ndarray:
        """Return the feature-to-scope mapping of the wrapped input module."""
        return self.inputs.feature_to_scope

    def extra_repr(self) -> str:
        """Return the parent representation extended with the weight shape."""
        return f"{super().extra_repr()}, weights={self.weights_shape}"

    def log_likelihood(self, data: Tensor, cache: Cache | None = None) -> Tensor:  # type: ignore[override]
        """Always raise: a signed linear combination has no log-likelihood.

        Raises:
            UnsupportedOperationError: Unconditionally.
        """
        raise UnsupportedOperationError(
            "SignedSum does not define log_likelihood() because outputs may be negative. "
            "Use SOCS (sum of squares) or signed evaluation utilities."
        )

    def _expectation_maximization_step(
        self,
        data: Tensor,
        bias_correction: bool = True,
        *,
        cache: Cache,
    ) -> None:
        """Always raise: expectation-maximization is not supported.

        Raises:
            UnsupportedOperationError: Unconditionally.
        """
        raise UnsupportedOperationError("SignedSum does not support expectation-maximization.")

    def marginalize(
        self,
        marg_rvs: list[int],
        prune: bool = True,
        cache: Cache | None = None,
    ) -> Module | None:
        """Always raise: marginalization is not supported.

        Raises:
            UnsupportedOperationError: Unconditionally.
        """
        raise UnsupportedOperationError("SignedSum.marginalize() is not supported.")

    @cached
    def signed_logabs_and_sign(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> tuple[Tensor, Tensor]:
        """Evaluate this node in ``(log|·|, sign)`` form.

        Args:
            data: Input data forwarded unchanged to the child evaluation.
            cache: Cache used to memoize the result for this module.
                NOTE(review): ``cache`` is dereferenced without a ``None`` check;
                presumably the ``@cached`` decorator substitutes a fresh cache
                when ``None`` is passed -- confirm against ``spflow.utils.cache``.

        Returns:
            logabs: Tensor of shape (B, F, OC, R)
            sign: Tensor of shape (B, F, OC, R) in {-1,0,+1}

        Raises:
            ShapeError: If the child's log-abs tensor is not 4-dimensional.
        """
        # Cache keying: re-use `Cache` internal dicts with a custom method name.
        cached_ = cache.get("signed_logabs_and_sign", self)
        if cached_ is not None:
            return cached_

        # Evaluate children in signed semiring
        child = self.inputs
        if hasattr(child, "signed_logabs_and_sign"):
            child_logabs, child_sign = child.signed_logabs_and_sign(data, cache=cache)  # type: ignore[attr-defined]
        else:
            # Recursively evaluate mixed sub-graphs (e.g., Product/Cat containing SignedSum leaves).
            from spflow.modules.sos.socs import _signed_eval

            child_logabs, child_sign = _signed_eval(cast(Module, child), data, cache)

        # child_* shapes: (B, F, IC, R)
        if child_logabs.dim() != 4:
            raise ShapeError(
                f"Expected child signed evaluation to be 4D (B,F,C,R), got shape {tuple(child_logabs.shape)}."
            )

        # Insert a singleton output-channel axis so children broadcast against the weights.
        ll = rearrange(child_logabs, "b f ci r -> b f ci 1 r")
        ss = rearrange(child_sign, "b f ci r -> b f ci 1 r").to(dtype=torch.int8)

        # Insert a singleton batch axis on the weights.
        w = rearrange(self.weights, "f ci co r -> 1 f ci co r")
        w_sign = sign_of(w).to(dtype=torch.int8)
        # Avoid log(0) for zero weights; zero-weight terms should contribute nothing.
        w_logabs = torch.log(torch.abs(w).clamp_min(1e-30))

        # Terms: w * x = sign(w)*sign(x) * exp(log|w|+log|x|)
        term_sign = (w_sign * ss).to(dtype=torch.int8)
        term_logabs = ll + w_logabs  # broadcast over batch

        out_logabs, out_sign = signed_logsumexp(
            logabs_terms=term_logabs,
            sign_terms=term_sign,
            dim=self.sum_dim + 1,  # reduce IC
            keepdim=False,
            eps=0.0,
        )

        cache.set("signed_logabs_and_sign", self, (out_logabs, out_sign))
        return out_logabs, out_sign

    def _sample(
        self,
        data: Tensor,
        sampling_ctx: SamplingContext,
        cache: Cache,
    ) -> Tensor:
        """Sample from the unnormalized non-negative subset of SignedSum.

        Only supported if all weights are >= 0 and no evidence is provided.
        The weights selected by the parent's output-channel index are normalized
        over input channels and used to pick one input channel per feature
        (argmax in MPE mode, a categorical draw otherwise). The choice is written
        to ``sampling_ctx`` and sampling is delegated to the input module.

        Args:
            data: Sample buffer; it must contain no finite entries, since finite
                values are treated as evidence.
            sampling_ctx: Sampling context carrying the parent's channel index;
                its channel index is updated with the chosen input channels.
            cache: Cache forwarded to the input module.

        Returns:
            The ``data`` tensor that was passed in (presumably filled by the
            input module's ``_sample`` -- confirm against the inputs).

        Raises:
            UnsupportedOperationError: For differentiable routing, evidence in
                ``data``, negative weights, or ``num_repetitions != 1``.
            ShapeError: If the weights are not 4D or the context's channel index
                has an unexpected feature width.
        """
        if sampling_ctx.is_differentiable:
            raise UnsupportedOperationError(
                "SignedSum.sample() does not support differentiable routing yet."
            )
        if torch.isfinite(data).any():
            raise UnsupportedOperationError(
                "SignedSum.sample() does not support conditional sampling with evidence."
            )
        if (self.weights < 0).any():
            raise UnsupportedOperationError(
                "SignedSum.sample() is only supported when all weights are non-negative."
            )

        # Only supports scalar feature routing like Sum: choose input-channel per feature.
        # We treat weights as proportional probabilities.
        w = self.weights
        if w.dim() != 4:
            raise ShapeError(f"Expected weights to be 4D, got shape {tuple(w.shape)}.")
        if self.out_shape.repetitions != 1:
            raise UnsupportedOperationError(
                "SignedSum.sample() currently supports num_repetitions == 1 only."
            )

        sampling_ctx.validate_sampling_context(
            num_samples=data.shape[0],
            num_features=self.out_shape.features,
            num_channels=self.out_shape.channels,
            num_repetitions=self.out_shape.repetitions,
            allowed_feature_widths=(1, self.out_shape.features),
        )

        # Current output channel selection from parent
        batch_size = int(data.shape[0])
        num_weight_features = int(w.shape[0])
        in_channels_total = int(w.shape[1])
        context_features = int(sampling_ctx.channel_index.shape[1])
        if context_features == num_weight_features:
            oidx = repeat(sampling_ctx.channel_index, "b f -> b f ci 1", ci=in_channels_total)
        elif context_features == 1:
            # A single shared index is broadcast across all features.
            oidx = repeat(
                sampling_ctx.channel_index,
                "b 1 -> b f ci 1",
                f=num_weight_features,
                ci=in_channels_total,
            )
        else:
            raise ShapeError(
                f"Expected channel_index feature width 1 or {num_weight_features}, got {context_features}."
            )

        # Repetition 0 is the only one (checked above); pick the selected output channel.
        w_sel = repeat(w[..., 0], "f ci co -> b f ci co", b=batch_size)
        w_sel = w_sel.gather(dim=3, index=oidx)
        w_sel = rearrange(w_sel, "b f ci 1 -> b f ci")

        # Normalize over input channels
        probs = w_sel / w_sel.sum(dim=2, keepdim=True).clamp_min(1e-12)

        if sampling_ctx.is_mpe:
            new_channel_index = torch.argmax(probs, dim=-1)
        else:
            new_channel_index = torch.distributions.Categorical(probs=probs).sample()

        update_channel_index_strict(sampling_ctx, new_channel_index)
        self.inputs._sample(data=data, cache=cache, sampling_ctx=sampling_ctx)
        return data