# Source code for spflow.utils.signed_semiring
"""Real signed-semirings utilities for numerically stable circuit evaluation.
This module implements a sign-aware log-absolute-value representation:
any real value `x` is represented as a pair `(log|x|, sign(x))`.
This is used to evaluate circuits that may contain negative parameters (e.g.,
SignedSum) while avoiding underflow/overflow when magnitudes are very small/large.
"""
from __future__ import annotations
import torch
from torch import Tensor
# [docs]
def sign_of(x: Tensor) -> Tensor:
    """Map every element of ``x`` to its sign, encoded as an ``int8`` tensor.

    Args:
        x: Input tensor.

    Returns:
        A tensor of the same shape as ``x`` with values in {-1, 0, +1}.
    """
    signs = torch.sign(x)
    return signs.to(dtype=torch.int8)
# [docs]
def logabs_of(x: Tensor, eps: float = 0.0) -> Tensor:
    """Return log(|x|), with optional epsilon to avoid log(0).

    Args:
        x: Input tensor.
        eps: If > 0, computes log(|x| + eps).

    Returns:
        Element-wise log-magnitude of ``x`` (``-inf`` where ``x == 0``
        and ``eps == 0``).
    """
    magnitude = torch.abs(x)
    if eps > 0.0:
        magnitude = magnitude + x.new_tensor(eps)
    return torch.log(magnitude)
# [docs]
def signed_logsumexp(
    logabs_terms: Tensor,
    sign_terms: Tensor,
    dim: int,
    keepdim: bool = False,
    eps: float = 0.0,
) -> tuple[Tensor, Tensor]:
    """Compute log|Σ_i s_i exp(a_i)| and sign of the sum in a stable way.

    Args:
        logabs_terms: Log-absolute-values `a_i` of the terms.
        sign_terms: Signs `s_i` of the terms in {-1, 0, +1}. Must be broadcastable.
        dim: Dimension to reduce over.
        keepdim: Whether to keep the reduced dimension.
        eps: Additive epsilon to avoid log(0) in edge cases.

    Returns:
        (logabs_sum, sign_sum): log-magnitude of the signed sum along `dim`,
        and its sign in {-1, 0, +1} (0 when the sum is exactly zero, e.g. on
        exact cancellation or when every term is -inf).

    Raises:
        ValueError: If `logabs_terms` is empty.
    """
    if logabs_terms.numel() == 0:
        raise ValueError("signed_logsumexp requires at least one term.")
    # m = max(a_i) per reduced slice, kept for broadcasting below.
    m = torch.max(logabs_terms, dim=dim, keepdim=True).values
    # Slices where every a_i is -inf represent a sum of exact zeros.
    all_neg_inf = torch.isneginf(m)
    # Subtract a *finite* shift everywhere: using 0 where m == -inf makes
    # exp(a_i - shift) evaluate to exp(-inf) == 0 directly, instead of
    # exp((-inf) - (-inf)) == exp(nan).  The previous approach masked the
    # NaNs after the fact with torch.where, which still leaks NaN into
    # gradients (torch.where backprop is NaN-unsafe when the unselected
    # branch already holds NaN).  This way no NaN is ever created; the
    # forward results are identical.
    shift = torch.where(all_neg_inf, torch.zeros_like(m), m)
    # exp(a_i - shift) is in [0, 1] for finite a_i, 0 for a_i == -inf.
    scaled = torch.exp(logabs_terms - shift)
    signed_scaled = scaled * sign_terms.to(dtype=scaled.dtype)
    s = torch.sum(signed_scaled, dim=dim, keepdim=True)
    # Sign of the reduced sum (equivalent to sign_of(s); inlined so the
    # function is self-contained).
    sign_s = torch.sign(s).to(dtype=torch.int8)
    abs_s = torch.abs(s)
    if eps > 0.0:
        abs_s = abs_s + abs_s.new_tensor(eps)
    out_logabs = m + torch.log(abs_s)
    # All--inf slices must come out as exactly -inf.  With eps == 0 the sum
    # m + log(0) is already -inf, but with eps > 0 the log term is finite,
    # so overwrite explicitly to keep the result exact in both cases.
    out_logabs = torch.where(
        all_neg_inf, torch.full_like(out_logabs, float("-inf")), out_logabs
    )
    if not keepdim:
        out_logabs = out_logabs.squeeze(dim)
        sign_s = sign_s.squeeze(dim)
    return out_logabs, sign_s