"""Shared inner-/triple-product utilities for probabilistic circuits.
This module implements the dynamic programs used by SOS/SOCS normalization:
- Pairwise inner products: ∫ a(x) b(x) dx
- Triple products: ∫ a(x) b(x) c(x) dx
Both are computed bottom-up using circuit structure (Cat/Product/(Signed)Sum)
and analytic leaf integrals when available.
Two wrappers use this core:
- `spflow.utils.inner_product` (for `spflow.modules.sos`)
- `spflow.utils.inner_product` (canonical SOS/SOCS entry point)
"""
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import cast

import torch
from einops import rearrange
from torch import Tensor

from spflow.exceptions import ShapeError, UnsupportedOperationError
from spflow.modules.leaves.bernoulli import Bernoulli
from spflow.modules.leaves.binomial import Binomial
from spflow.modules.leaves.categorical import Categorical
from spflow.modules.leaves.cltree import CLTree
from spflow.modules.leaves.exponential import Exponential
from spflow.modules.leaves.gamma import Gamma
from spflow.modules.leaves.geometric import Geometric
from spflow.modules.leaves.histogram import Histogram
from spflow.modules.leaves.hypergeometric import Hypergeometric
from spflow.modules.leaves.laplace import Laplace
from spflow.modules.leaves.leaf import LeafModule
from spflow.modules.leaves.log_normal import LogNormal
from spflow.modules.leaves.negative_binomial import NegativeBinomial
from spflow.modules.leaves.normal import Normal
from spflow.modules.leaves.piecewise_linear import PiecewiseLinear
from spflow.modules.leaves.poisson import Poisson
from spflow.modules.leaves.uniform import Uniform
from spflow.modules.module import Module
from spflow.modules.ops.cat import Cat
from spflow.modules.products.product import Product
from spflow.modules.sums.sum import Sum
from spflow.utils.cache import Cache
from spflow.utils.domain import DataType
def _ensure_same_scope(a: Module, b: Module) -> None:
    """Raise ``ShapeError`` when the two modules do not share the same scope."""
    scope_a, scope_b = a.scope, b.scope
    if scope_a != scope_b:
        raise ShapeError(f"Scopes must match: {scope_a} vs {scope_b}.")
def _leaf_event_shape_ok(a: LeafModule, b: LeafModule) -> None:
    """Raise ``ShapeError`` unless both leaves agree on features and repetitions.

    Features are compared first, then repetitions; channels are not compared.
    """
    checks = (
        ("features", "Leaf features must match for inner product."),
        ("repetitions", "Leaf repetitions must match for inner product."),
    )
    shape_a, shape_b = a.out_shape, b.out_shape
    for attr, message in checks:
        if getattr(shape_a, attr) != getattr(shape_b, attr):
            raise ShapeError(message)
def _binomial_logpmf(k: Tensor, n: Tensor, p: Tensor) -> Tensor:
    """Return the Binomial log-pmf ``log C(n,k) + k log p + (n-k) log(1-p)``.

    The arguments are broadcast against each other. Values of ``k`` outside
    ``[0, n]`` are not masked here; the caller is responsible for that.
    Probabilities are floored at 1e-30 before taking logs.
    """
    log_choose = torch.lgamma(n + 1.0) - torch.lgamma(k + 1.0) - torch.lgamma(n - k + 1.0)
    log_success = torch.log(p.clamp_min(1e-30))
    log_failure = torch.log((1.0 - p).clamp_min(1e-30))
    return log_choose + k * log_success + (n - k) * log_failure
def _hypergeo_logpmf(k: Tensor, K: Tensor, N: Tensor, n: Tensor) -> Tensor:
    """Return the Hypergeometric log-pmf ``log [C(K,k) C(N-K, n-k) / C(N,n)]``.

    Binomial coefficients are evaluated through ``lgamma``; no support masking
    is applied here.
    """

    def _log_choose(total: Tensor, chosen: Tensor) -> Tensor:
        # log C(total, chosen) via the log-gamma function.
        return torch.lgamma(total + 1.0) - torch.lgamma(chosen + 1.0) - torch.lgamma(total - chosen + 1.0)

    successes_term = _log_choose(K, k)
    failures_term = _log_choose(N - K, n - k)
    normaliser = _log_choose(N, n)
    return successes_term + failures_term - normaliser
def _neg_binom_logpmf(k: Tensor, r: Tensor, p: Tensor) -> Tensor:
    """Return the Negative-Binomial log-pmf ``log [C(k+r-1, k) p^k (1-p)^r]``.

    This follows the torch ``NegativeBinomial`` parameterisation noted in the
    original comment: ``k >= 0`` successes before ``total_count = r`` failures
    with success probability ``p``. Probabilities are floored at 1e-30 before
    taking logs.
    """
    log_coeff = torch.lgamma(k + r) - torch.lgamma(k + 1.0) - torch.lgamma(r)
    log_p = torch.log(p.clamp_min(1e-30))
    log_one_minus_p = torch.log((1.0 - p).clamp_min(1e-30))
    return log_coeff + k * log_p + r * log_one_minus_p
def _series_logsumexp(
    *,
    log_terms_fn: Callable[[int], Tensor],
    max_k: int,
    tol: float,
    device: torch.device,
    dtype: torch.dtype = torch.float64,
    min_terms: int = 32,
) -> Tensor:
    """Accumulate a positive series ``sum_k exp(log_terms_fn(k))`` in log-space.

    Args:
        log_terms_fn: Maps the term index ``k`` (``0..max_k``) to the log of the
            k-th term. Its result is combined with the running total through
            ``torch.logaddexp``, so it must broadcast against it.
        max_k: Last index that is evaluated (inclusive).
        tol: Relative tolerance. Once ``k >= min_terms``, the loop stops as soon
            as every element satisfies ``exp(term - total) < tol``.
        device: Device for the initial accumulator and the tolerance threshold.
        dtype: Dtype for the initial accumulator and the tolerance threshold.
        min_terms: Smallest index at which the early-stopping test is applied.

    Returns:
        The log of the (possibly truncated) series sum. If the tolerance is
        never met the sum over ``k = 0..max_k`` is returned without warning.
    """
    log_total = torch.full((), float("-inf"), dtype=dtype, device=device)
    # Loop-invariant threshold: build it once instead of on every iteration.
    log_tol = torch.log(torch.tensor(tol, dtype=dtype, device=device))
    for k in range(max_k + 1):
        log_term = log_terms_fn(k)
        log_total = torch.logaddexp(log_total, log_term)
        # Relative contribution bound: exp(log_term - log_total) < tol everywhere.
        if k >= min_terms and torch.all((log_term - log_total) < log_tol):
            break
    return log_total
def leaf_inner_product(a: Module, b: Module) -> Tensor:
    """Compute per-feature/channel inner products ∫ f_a(x) f_b(x) dx for leaves.

    Both leaves must share scope, feature count and repetition count. Every
    supported branch produces a float64 tensor laid out as
    ``(features, channels_a, channels_b, repetitions)``.

    Supported pairs are same-family pairs of Normal, Bernoulli, Categorical,
    Exponential, Laplace, LogNormal, Poisson, Gamma, Uniform, Geometric,
    Binomial, Hypergeometric, NegativeBinomial, Histogram, PiecewiseLinear and
    CLTree, plus any Categorical/SignedCategorical combination when the
    optional ``SignedCategorical`` leaf can be imported.

    Args:
        a: Left leaf module.
        b: Right leaf module.

    Returns:
        Tensor of pairwise inner products between the channels of ``a`` and
        the channels of ``b``.

    Raises:
        ShapeError: On scope/feature/repetition mismatch, mismatching ``K`` for
            categorical-like leaves or CLTrees, or mismatching population size
            ``N`` for Hypergeometric leaves.
        UnsupportedOperationError: If the leaf pair is not supported, a Gamma
            pair is not integrable, PiecewiseLinear leaves are uninitialized,
            lack domains or are not continuous, or CLTree structures differ.
    """
    _ensure_same_scope(a, b)
    _leaf_event_shape_ok(a, b)
    try:
        from spflow.zoo.sos.signed_categorical import SignedCategorical as _SignedCategorical
    except Exception:  # pragma: no cover - optional zoo import
        _SignedCategorical = None  # type: ignore[assignment]

    def _a_side(param: Tensor) -> Tensor:
        # Parameter of `a`: insert a singleton axis where b's channels go.
        return rearrange(param.to(dtype=torch.float64), "f ci r -> f ci 1 r")

    def _b_side(param: Tensor) -> Tensor:
        # Parameter of `b`: insert a singleton axis where a's channels go.
        return rearrange(param.to(dtype=torch.float64), "f co r -> f 1 co r")

    def _a_side_k(param: Tensor) -> Tensor:
        # Same as `_a_side` for parameters with a trailing category axis.
        return rearrange(param.to(dtype=torch.float64), "f ci r k -> f ci 1 r k")

    def _b_side_k(param: Tensor) -> Tensor:
        # Same as `_b_side` for parameters with a trailing category axis.
        return rearrange(param.to(dtype=torch.float64), "f co r k -> f 1 co r k")

    def _with_sum_axis(t: Tensor) -> Tensor:
        # Prepend the axis over which finite-support sums are taken.
        return rearrange(t, "f ci co r -> 1 f ci co r")

    if isinstance(a, Normal) and isinstance(b, Normal):
        mu1 = _a_side(a.loc)
        mu2 = _b_side(b.loc)
        s1 = _a_side(a.scale)
        s2 = _b_side(b.scale)
        # Product of two Gaussians integrates to N(mu1 - mu2; 0, s1^2 + s2^2).
        var = s1.pow(2) + s2.pow(2)
        log_coeff = -0.5 * torch.log(2.0 * torch.pi * var)
        quad = -(mu1 - mu2).pow(2) / (2.0 * var)
        return torch.exp(log_coeff + quad)

    if isinstance(a, Bernoulli) and isinstance(b, Bernoulli):
        p1 = _a_side(a.probs)
        p2 = _b_side(b.probs)
        return p1 * p2 + (1.0 - p1) * (1.0 - p2)

    if isinstance(a, Categorical) and isinstance(b, Categorical):
        if a.K != b.K:
            raise ShapeError(f"Categorical K mismatch: {a.K} vs {b.K}.")
        p1 = _a_side_k(a.probs)
        p2 = _b_side_k(b.probs)
        return torch.sum(p1 * p2, dim=-1)

    if _SignedCategorical is not None:
        if isinstance(a, _SignedCategorical) and isinstance(b, _SignedCategorical):
            if a.K != b.K:
                raise ShapeError(f"SignedCategorical K mismatch: {a.K} vs {b.K}.")
            w1 = _a_side_k(a.weights)
            w2 = _b_side_k(b.weights)
            return torch.sum(w1 * w2, dim=-1)
        if isinstance(a, _SignedCategorical) and isinstance(b, Categorical):
            if a.K != b.K:
                raise ShapeError(f"Categorical K mismatch: {a.K} vs {b.K}.")
            w1 = _a_side_k(a.weights)
            p2 = _b_side_k(b.probs)
            return torch.sum(w1 * p2, dim=-1)
        if isinstance(a, Categorical) and isinstance(b, _SignedCategorical):
            if a.K != b.K:
                raise ShapeError(f"Categorical K mismatch: {a.K} vs {b.K}.")
            p1 = _a_side_k(a.probs)
            w2 = _b_side_k(b.weights)
            return torch.sum(p1 * w2, dim=-1)

    if isinstance(a, Exponential) and isinstance(b, Exponential):
        r1 = _a_side(a.rate)
        r2 = _b_side(b.rate)
        return (r1 * r2) / (r1 + r2).clamp_min(1e-30)

    if isinstance(a, Laplace) and isinstance(b, Laplace):
        mu1 = _a_side(a.loc)
        mu2 = _b_side(b.loc)
        b1 = _a_side(a.scale).clamp_min(1e-30)
        b2 = _b_side(b.scale).clamp_min(1e-30)
        d = torch.abs(mu1 - mu2)
        exp1 = torch.exp(-d / b1)
        exp2 = torch.exp(-d / b2)
        # Contribution of the two tails outside [min(mu), max(mu)].
        term_tails = (exp1 + exp2) / (4.0 * (b1 + b2))
        same = torch.isclose(b1, b2)
        # Keep the denominator non-zero where the equal-scale limit is used, so
        # the discarded torch.where branch cannot inject inf/NaN into gradients.
        scale_diff = b1 - b2
        safe_diff = torch.where(same, torch.ones_like(scale_diff), scale_diff)
        term_mid = (exp1 - exp2) / (4.0 * safe_diff)
        # Limit of term_mid as b2 -> b1.
        term_mid_same = torch.exp(-d / b1) * d / (4.0 * b1.pow(2))
        return torch.where(same, term_tails + term_mid_same, term_tails + term_mid)

    if isinstance(a, LogNormal) and isinstance(b, LogNormal):
        mu1 = _a_side(a.loc)
        mu2 = _b_side(b.loc)
        s1 = _a_side(a.scale).clamp_min(1e-30)
        s2 = _b_side(b.scale).clamp_min(1e-30)
        # After substituting y = log x the integrand is exp(-A y^2 / 2 + D y + E)
        # up to the prefactor; the "- 1.0" in D comes from the 1/x^2 * dx Jacobian.
        a1 = 1.0 / s1.pow(2)
        a2 = 1.0 / s2.pow(2)
        A = a1 + a2
        D = (a1 * mu1 + a2 * mu2) - 1.0
        E = -0.5 * (a1 * mu1.pow(2) + a2 * mu2.pow(2))
        log_pref = -0.5 * torch.log(2.0 * torch.pi * (s1.pow(2) + s2.pow(2)))
        return torch.exp(log_pref + E + (D.pow(2) / (2.0 * A)))

    if isinstance(a, Poisson) and isinstance(b, Poisson):
        l1 = _a_side(a.rate).clamp_min(0.0)
        l2 = _b_side(b.rate).clamp_min(0.0)
        z = 2.0 * torch.sqrt((l1 * l2).clamp_min(0.0))
        # sum_k e^{-(l1+l2)} (l1 l2)^k / (k!)^2 = e^{-(l1+l2)} I0(z).
        # Evaluating exp(-(l1+l2)) and I0(z) separately gives 0 * inf = NaN for
        # large rates, so use the exponentially scaled Bessel function
        # i0e(z) = exp(-|z|) I0(z) and fold exp(z) into the exponent (z >= 0).
        return torch.exp(z - (l1 + l2)) * torch.special.i0e(z)

    if isinstance(a, Gamma) and isinstance(b, Gamma):
        a1 = _a_side(a.concentration).clamp_min(1e-30)
        a2 = _b_side(b.concentration).clamp_min(1e-30)
        b1 = _a_side(a.rate).clamp_min(1e-30)
        b2 = _b_side(b.rate).clamp_min(1e-30)
        s = a1 + a2 - 1.0
        if (s <= 0.0).any():
            raise UnsupportedOperationError(
                "Gamma inner product requires concentration_a + concentration_b > 1 for integrability."
            )
        log_ip = (
            a1 * torch.log(b1)
            + a2 * torch.log(b2)
            + torch.lgamma(s)
            - torch.lgamma(a1)
            - torch.lgamma(a2)
            - s * torch.log(b1 + b2)
        )
        return torch.exp(log_ip)

    if isinstance(a, Uniform) and isinstance(b, Uniform):
        a1 = _a_side(a.low)
        b1 = _a_side(a.high)
        a2 = _b_side(b.low)
        b2 = _b_side(b.high)
        len1 = (b1 - a1).clamp_min(1e-30)
        len2 = (b2 - a2).clamp_min(1e-30)
        left = torch.maximum(a1, a2)
        right = torch.minimum(b1, b2)
        overlap = (right - left).clamp_min(0.0)
        return overlap / (len1 * len2)

    if isinstance(a, Geometric) and isinstance(b, Geometric):
        p1 = _a_side(a.probs).clamp_min(0.0).clamp_max(1.0)
        p2 = _b_side(b.probs).clamp_min(0.0).clamp_max(1.0)
        q1 = 1.0 - p1
        q2 = 1.0 - p2
        denom = 1.0 - (q1 * q2)
        return (p1 * p2) / denom.clamp_min(1e-30)

    if isinstance(a, Binomial) and isinstance(b, Binomial):
        n1 = _a_side(a.total_count)
        n2 = _b_side(b.total_count)
        p1 = _a_side(a.probs)
        p2 = _b_side(b.probs)
        max_n = int(torch.max(torch.maximum(n1, n2)).item())
        ks = rearrange(torch.arange(0, max_n + 1, dtype=torch.float64, device=p1.device), "k -> k 1 1 1 1")
        n1b = _with_sum_axis(n1)
        n2b = _with_sum_axis(n2)
        lp1 = _binomial_logpmf(ks, n1b, _with_sum_axis(p1))
        lp2 = _binomial_logpmf(ks, n2b, _with_sum_axis(p2))
        # Only k within both supports contributes.
        mask = (ks <= n1b) & (ks <= n2b)
        lsum = torch.logsumexp(torch.where(mask, lp1 + lp2, torch.full_like(lp1, float("-inf"))), dim=0)
        return torch.exp(lsum)

    if isinstance(a, Hypergeometric) and isinstance(b, Hypergeometric):
        K1 = _a_side(a.K)
        N1 = _a_side(a.N)
        n1 = _a_side(a.n)
        K2 = _b_side(b.K)
        N2 = _b_side(b.N)
        n2 = _b_side(b.n)
        if not torch.allclose(N1, N2):
            raise ShapeError("Hypergeometric inner product requires matching N (population size).")
        N = N1
        max_k = int(torch.max(torch.minimum(torch.minimum(n1, K1), torch.minimum(n2, K2))).item())
        ks = rearrange(torch.arange(0, max_k + 1, dtype=torch.float64, device=N.device), "k -> k 1 1 1 1")
        K1b = _with_sum_axis(K1)
        K2b = _with_sum_axis(K2)
        Nb = _with_sum_axis(N)
        n1b = _with_sum_axis(n1)
        n2b = _with_sum_axis(n2)
        lp1 = _hypergeo_logpmf(ks, K1b, Nb, n1b)
        lp2 = _hypergeo_logpmf(ks, K2b, Nb, n2b)
        # Support of each pmf: max(0, n + K - N) <= k <= min(n, K).
        min1 = _with_sum_axis(torch.maximum(torch.zeros_like(N), n1 + K1 - N))
        max1 = _with_sum_axis(torch.minimum(n1, K1))
        min2 = _with_sum_axis(torch.maximum(torch.zeros_like(N), n2 + K2 - N))
        max2 = _with_sum_axis(torch.minimum(n2, K2))
        mask = (ks >= min1) & (ks <= max1) & (ks >= min2) & (ks <= max2)
        lsum = torch.logsumexp(torch.where(mask, lp1 + lp2, torch.full_like(lp1, float("-inf"))), dim=0)
        return torch.exp(lsum)

    if isinstance(a, NegativeBinomial) and isinstance(b, NegativeBinomial):
        r1 = _a_side(a.total_count)
        r2 = _b_side(b.total_count)
        p1 = _a_side(a.probs).clamp_min(1e-30).clamp_max(1.0)
        p2 = _b_side(b.probs).clamp_min(1e-30).clamp_max(1.0)
        q = p1 * p2

        def log_term(k: int) -> Tensor:
            kk = torch.tensor(float(k), dtype=torch.float64, device=q.device)
            # (r1)_k (r2)_k / (k!)^2 * (p1 p2)^k * (1-p1)^r1 (1-p2)^r2
            lt = (
                torch.lgamma(r1 + kk)
                - torch.lgamma(r1)
                + torch.lgamma(r2 + kk)
                - torch.lgamma(r2)
                - 2.0 * torch.lgamma(kk + 1.0)
                + kk * torch.log(q.clamp_min(1e-30))
            )
            const = r1 * torch.log((1.0 - p1).clamp_min(1e-30)) + r2 * torch.log((1.0 - p2).clamp_min(1e-30))
            return lt + const

        logS = _series_logsumexp(log_terms_fn=log_term, max_k=4096, tol=1e-12, device=q.device)
        return torch.exp(logS)

    if isinstance(a, Histogram) and isinstance(b, Histogram):
        # Univariate leaf: per-feature inner product is computed independently.
        edges1 = a.bin_edges.to(dtype=torch.float64, device=a.device)
        edges2 = b.bin_edges.to(dtype=torch.float64, device=b.device)
        # Merge both edge sets; on each resulting segment both densities are constant.
        u_edges = torch.unique(torch.cat([edges1, edges2])).to(dtype=torch.float64)
        u_edges, _ = torch.sort(u_edges)
        seg_left = u_edges[:-1]
        seg_right = u_edges[1:]
        seg_len = (seg_right - seg_left).clamp_min(0.0)
        mids = (seg_left + seg_right) / 2.0
        widths1 = (edges1[1:] - edges1[:-1]).to(dtype=torch.float64)
        widths2 = (edges2[1:] - edges2[:-1]).to(dtype=torch.float64)
        dens1 = rearrange(
            a.probs.to(dtype=torch.float64) / rearrange(widths1, "k -> 1 1 1 k"),
            "f ci r b1 -> f ci 1 r b1",
        )
        dens2 = rearrange(
            b.probs.to(dtype=torch.float64) / rearrange(widths2, "k -> 1 1 1 k"),
            "f co r b2 -> f 1 co r b2",
        )
        # Locate the source bin of every merged segment via its midpoint.
        idx1 = (torch.bucketize(mids, edges1, right=True) - 1).clamp(0, widths1.numel() - 1)
        idx2 = (torch.bucketize(mids, edges2, right=True) - 1).clamp(0, widths2.numel() - 1)
        in1 = (mids >= edges1[0]) & (mids < edges1[-1])
        in2 = (mids >= edges2[0]) & (mids < edges2[-1])
        mask = (in1 & in2).to(dtype=torch.float64)
        # Do not squeeze the trailing axis: with a single merged segment (S == 1)
        # that would drop the segment axis and break the reduction below.
        d1 = dens1.index_select(-1, idx1)  # (F,Ca,1,R,S)
        d2 = dens2.index_select(-1, idx2)  # (F,1,Cb,R,S)
        prod = d1 * d2  # (F,Ca,Cb,R,S)
        out = torch.sum(prod * rearrange(seg_len * mask, "s -> 1 1 1 1 s"), dim=-1)
        return out

    if isinstance(a, PiecewiseLinear) and isinstance(b, PiecewiseLinear):
        if not a.is_initialized or not b.is_initialized:
            raise UnsupportedOperationError(
                "PiecewiseLinear inner product requires both leaves to be initialized."
            )
        if a.domains is None or b.domains is None:
            raise UnsupportedOperationError("PiecewiseLinear inner product requires domains.")
        # Only support continuous domains for now (checked on both operands).
        for dom in (*a.domains, *b.domains):
            if dom.data_type != DataType.CONTINUOUS:
                raise UnsupportedOperationError(
                    "PiecewiseLinear inner product currently supports continuous domains only."
                )
        # Local import, hoisted out of the nested loops below.
        from spflow.modules.leaves.piecewise_linear import interp

        dist_a = a.distribution()
        dist_b = b.distribution()
        F, Ca, Cb, R = (
            a.out_shape.features,
            a.out_shape.channels,
            b.out_shape.channels,
            a.out_shape.repetitions,
        )
        out = torch.empty((F, Ca, Cb, R), dtype=torch.float64, device=a.device)

        def _get_knots(dist, r: int, leaf_idx: int, f: int) -> tuple[Tensor, Tensor]:
            xs = dist.xs[r][leaf_idx][f][0]
            ys = dist.ys[r][leaf_idx][f][0]
            return xs.to(dtype=torch.float64), ys.to(dtype=torch.float64)

        for r in range(R):
            for ca in range(Ca):
                for cb in range(Cb):
                    for f in range(F):
                        xa, ya = _get_knots(dist_a, r, ca, f)
                        xb, yb = _get_knots(dist_b, r, cb, f)
                        grid = torch.unique(torch.cat([xa, xb]))
                        grid, _ = torch.sort(grid)
                        if grid.numel() < 2:
                            out[f, ca, cb, r] = 0.0
                            continue
                        # Evaluate at grid points using the leaf's interpolation helper.
                        fa = interp(grid, xa, ya, extrapolate="constant")
                        fb = interp(grid, xb, yb, extrapolate="constant")
                        h = (grid[1:] - grid[:-1]).clamp_min(0.0)
                        f0, f1 = fa[:-1], fa[1:]
                        g0, g1 = fb[:-1], fb[1:]
                        # Exact integral of a product of two linear functions on each segment.
                        integral = torch.sum(h / 6.0 * (2 * f0 * g0 + f0 * g1 + f1 * g0 + 2 * f1 * g1))
                        out[f, ca, cb, r] = integral
        return out

    if isinstance(a, CLTree) and isinstance(b, CLTree):
        if a.K != b.K:
            raise ShapeError(f"CLTree K mismatch: {a.K} vs {b.K}.")
        if not torch.equal(a.parents, b.parents):
            raise UnsupportedOperationError(
                "CLTree inner product requires identical tree structure (parents)."
            )
        parents = a.parents.tolist()
        root = parents.index(-1)
        children: list[list[int]] = [[] for _ in range(a.out_shape.features)]
        for child, parent in enumerate(parents):
            if parent == -1:
                continue
            children[parent].append(child)
        log_cpt_a = a.log_cpt.to(dtype=torch.float64)
        log_cpt_b = b.log_cpt.to(dtype=torch.float64)
        C1, C2 = a.out_shape.channels, b.out_shape.channels
        R = a.out_shape.repetitions
        K = a.K
        F = a.out_shape.features
        pa_root = torch.exp(log_cpt_a[root, :, :, :, 0])
        pb_root = torch.exp(log_cpt_b[root, :, :, :, 0])
        # msg[i] holds the upward message of node i, indexed by its parent's state.
        msg = torch.ones((F, C1, C2, R, K), dtype=torch.float64, device=log_cpt_a.device)
        post_order = a.post_order.tolist()
        for i in post_order:
            p = parents[i]
            if p == -1:
                continue
            prod_child = torch.ones((C1, C2, R, K), dtype=torch.float64, device=log_cpt_a.device)
            for ch in children[i]:
                prod_child = prod_child * msg[ch]
            pa = torch.exp(log_cpt_a[i])
            pb = torch.exp(log_cpt_b[i])
            phi = rearrange(pa, "ca r i o -> ca 1 r i o") * rearrange(pb, "cb r i o -> 1 cb r i o")
            # Sum out node i's own state, leaving a function of the parent state.
            msg_i = torch.einsum("abri,abrio->abro", prod_child, phi)
            msg[i] = msg_i
        prod_root = torch.ones((C1, C2, R, K), dtype=torch.float64, device=log_cpt_a.device)
        for ch in children[root]:
            prod_root = prod_root * msg[ch]
        phi_root = rearrange(pa_root, "ca r i -> ca 1 r i") * rearrange(pb_root, "cb r i -> 1 cb r i")
        z = torch.sum(phi_root * prod_root, dim=-1)
        # The joint value is stored in feature row 0 and the remaining rows stay 1.
        out = torch.ones((F, C1, C2, R), dtype=torch.float64, device=log_cpt_a.device)
        out[0] = z
        return out

    raise UnsupportedOperationError(
        f"Leaf inner product not implemented for {type(a).__name__} × {type(b).__name__}. "
        "Supported: Normal, Bernoulli, Categorical, Exponential, Laplace, LogNormal, Poisson, Gamma, "
        "Uniform, Geometric, Binomial, Hypergeometric, NegativeBinomial, Histogram, PiecewiseLinear, CLTree."
    )
def _get_pair_memo(cache: Cache, *, memo_key: str) -> dict[tuple[int, int], Tensor]:
    """Return the pairwise memo dict stored in ``cache.extras[memo_key]``.

    A fresh empty dict is created and stored when the entry is missing or
    ``None``.
    """
    existing = cache.extras.get(memo_key)
    if existing is not None:
        return cast(dict[tuple[int, int], Tensor], existing)
    fresh: dict = {}
    cache.extras[memo_key] = fresh
    return cast(dict[tuple[int, int], Tensor], fresh)
def inner_product_matrix(
    a: Module,
    b: Module,
    *,
    cache: Cache | None = None,
    signed_sum_types: Sequence[type[Module]] = (),
    memo_key: str = "_inner_product_memo",
) -> Tensor:
    """Compute the pairwise inner-product tensor between two circuits.

    The recursion follows the circuit structure: leaf pairs are delegated to
    ``leaf_inner_product``, ``Cat`` pairs are assembled from their inputs,
    ``Product`` pairs multiply over the feature axis and sum-like pairs contract
    the child tensor with both weight tensors.

    Args:
        a: Left module.
        b: Right module; must share scope, features and repetitions with ``a``.
        cache: Optional cache. When given, results are memoised in
            ``cache.extras[memo_key]`` keyed by ``(id(a), id(b))``; an entry for
            the swapped pair is reused with its channel axes swapped.
        signed_sum_types: Extra module types treated like ``Sum`` (they must
            expose ``inputs`` and ``weights``).
        memo_key: Key of the memo dict inside ``cache.extras``. It is forwarded
            to every recursive call so the whole recursion shares one memo.

    Returns:
        Tensor laid out as ``(features, channels_a, channels_b, repetitions)``.

    Raises:
        ShapeError: On scope, feature, repetition, ``Cat`` dim or ``Cat`` arity
            mismatch.
        UnsupportedOperationError: For unsupported module pairs or ``Cat`` dims.
    """
    key = (id(a), id(b))
    memo: dict[tuple[int, int], Tensor] | None = None
    if cache is not None:
        memo = _get_pair_memo(cache, memo_key=memo_key)
        cached = memo.get(key)
        if cached is not None:
            return cached
        # The (b, a) result is the same tensor with its channel axes swapped.
        rev = memo.get((id(b), id(a)))
        if rev is not None:
            out = rearrange(rev, "f ci co r -> f co ci r").contiguous()
            memo[key] = out
            return out

    def _remember(result: Tensor) -> Tensor:
        # Store the result for this (a, b) pair when memoisation is enabled.
        if memo is not None:
            memo[key] = result
        return result

    def _recurse(x: object, y: object) -> Tensor:
        # Recurse with identical settings, including memo_key.
        return inner_product_matrix(
            cast(Module, x),
            cast(Module, y),
            cache=cache,
            signed_sum_types=signed_sum_types,
            memo_key=memo_key,
        )

    _ensure_same_scope(a, b)
    if a.out_shape.features != b.out_shape.features:
        raise ShapeError(f"Feature mismatch: {a.out_shape.features} vs {b.out_shape.features}.")
    if a.out_shape.repetitions != b.out_shape.repetitions:
        raise ShapeError("Repetition mismatch for inner product.")
    try:
        from spflow.zoo.sos.signed_categorical import SignedCategorical as _SignedCategorical
    except Exception:  # pragma: no cover - optional zoo import
        _SignedCategorical = None  # type: ignore[assignment]

    a_is_leaf_like = isinstance(a, LeafModule) or (
        _SignedCategorical is not None and isinstance(a, _SignedCategorical)
    )
    b_is_leaf_like = isinstance(b, LeafModule) or (
        _SignedCategorical is not None and isinstance(b, _SignedCategorical)
    )
    if a_is_leaf_like and b_is_leaf_like:
        return _remember(leaf_inner_product(a, b))

    if isinstance(a, Cat) and isinstance(b, Cat):
        if a.dim != b.dim:
            raise ShapeError("Cat dim mismatch for inner product.")
        if a.dim == 1:
            # Feature-wise concatenation: inputs pair up one-to-one.
            if len(a.inputs) != len(b.inputs):
                raise ShapeError("Cat arity mismatch for inner product.")
            parts = [_recurse(ai, bi) for ai, bi in zip(a.inputs, b.inputs)]
            return _remember(torch.cat(parts, dim=0))
        if a.dim == 2:
            # Channel-wise concatenation: fill a block matrix over all input pairs.
            F = a.out_shape.features
            R = a.out_shape.repetitions
            Ca = sum(cast(Module, ai).out_shape.channels for ai in a.inputs)
            Cb = sum(cast(Module, bi).out_shape.channels for bi in b.inputs)
            blocks: list[list[Tensor]] = [[_recurse(ai, bi) for bi in b.inputs] for ai in a.inputs]
            out = torch.empty((F, Ca, Cb, R), dtype=torch.float64, device=blocks[0][0].device)
            a_off = 0
            for i, ai in enumerate(a.inputs):
                a_ch = cast(Module, ai).out_shape.channels
                b_off = 0
                for j, bi in enumerate(b.inputs):
                    b_ch = cast(Module, bi).out_shape.channels
                    out[:, a_off : a_off + a_ch, b_off : b_off + b_ch, :] = blocks[i][j]
                    b_off += b_ch
                a_off += a_ch
            return _remember(out)
        raise UnsupportedOperationError(f"inner_product does not support Cat(dim={a.dim}).")

    if isinstance(a, Product) and isinstance(b, Product):
        child_k = _recurse(a.inputs, b.inputs)
        # The per-feature inner products of the children multiply.
        return _remember(torch.prod(child_k, dim=0, keepdim=True))

    sum_types = (Sum, *signed_sum_types)
    if isinstance(a, sum_types) and isinstance(b, sum_types):
        child_k = _recurse(a.inputs, b.inputs)
        wa = a.weights.to(dtype=torch.float64)  # type: ignore[attr-defined]
        wb = b.weights.to(dtype=torch.float64)  # type: ignore[attr-defined]
        # Contract both weight tensors with the child tensor over the input channels.
        return _remember(torch.einsum("fiar,fijr,fjbr->fabr", wa, child_k, wb))

    raise UnsupportedOperationError(
        f"inner_product_matrix not implemented for {type(a).__name__} × {type(b).__name__}."
    )
def log_self_inner_product_scalar(
    module: Module,
    *,
    cache: Cache | None = None,
    signed_sum_types: Sequence[type[Module]] = (),
    memo_key: str = "_inner_product_memo",
) -> Tensor:
    """Return the log of a scalar module's inner product with itself.

    The module must have output shape ``(1, 1, 1)``. The single entry of
    ``inner_product_matrix(module, module)`` is clamped at zero from below and
    then floored at 1e-30 before the logarithm is taken.

    Raises:
        ShapeError: If ``module.out_shape`` is not ``(1, 1, 1)``.
    """
    shape = tuple(module.out_shape)
    if shape != (1, 1, 1):
        raise ShapeError(f"Expected scalar output (1,1,1), got {shape}.")
    gram = inner_product_matrix(
        module,
        module,
        cache=cache,
        signed_sum_types=signed_sum_types,
        memo_key=memo_key,
    )
    non_negative = torch.clamp(gram[0, 0, 0, 0], min=0.0)
    floored = non_negative.clamp_min(1e-30)
    return torch.log(floored)
def _get_triple_memo(cache: Cache, *, memo_key: str) -> dict[tuple[int, int, int], Tensor]:
    """Return the triple-product memo dict stored in ``cache.extras[memo_key]``.

    A fresh empty dict is created and stored when the entry is missing or
    ``None``.
    """
    existing = cache.extras.get(memo_key)
    if existing is not None:
        return cast(dict[tuple[int, int, int], Tensor], existing)
    fresh: dict = {}
    cache.extras[memo_key] = fresh
    return cast(dict[tuple[int, int, int], Tensor], fresh)
def triple_product_tensor(
a: Module,
b: Module,
c: Module,
*,
cache: Cache | None = None,
signed_sum_types: Sequence[type[Module]] = (),
memo_key: str = "_triple_product_memo",
) -> Tensor:
if cache is not None:
memo = _get_triple_memo(cache, memo_key=memo_key)
key = (id(a), id(b), id(c))
cached = memo.get(key)
if cached is not None:
return cached
swapped = memo.get((id(b), id(a), id(c)))
if swapped is not None:
out = rearrange(swapped, "f ci co cj r -> f co ci cj r").contiguous()
memo[key] = out
return out
_ensure_same_scope(a, b)
_ensure_same_scope(a, c)
if a.out_shape.features != b.out_shape.features or a.out_shape.features != c.out_shape.features:
raise ShapeError("Feature mismatch for triple product.")
if (
a.out_shape.repetitions != b.out_shape.repetitions
or a.out_shape.repetitions != c.out_shape.repetitions
):
raise ShapeError("Repetition mismatch for triple product.")
try:
from spflow.zoo.sos.signed_categorical import SignedCategorical as _SignedCategorical
except Exception: # pragma: no cover - optional zoo import
_SignedCategorical = None # type: ignore[assignment]
a_is_leaf_like = isinstance(a, LeafModule) or (
_SignedCategorical is not None and isinstance(a, _SignedCategorical)
)
b_is_leaf_like = isinstance(b, LeafModule) or (
_SignedCategorical is not None and isinstance(b, _SignedCategorical)
)
c_is_leaf_like = isinstance(c, LeafModule) or (
_SignedCategorical is not None and isinstance(c, _SignedCategorical)
)
if a_is_leaf_like and b_is_leaf_like and c_is_leaf_like:
# Handle a subset of leaf triples explicitly; otherwise reduce to finite sums/interval overlap where possible.
if (
_SignedCategorical is not None
and isinstance(a, (Categorical, _SignedCategorical))
and isinstance(b, (Categorical, _SignedCategorical))
and isinstance(c, (Categorical, _SignedCategorical))
):
Ks = (a.K, b.K, c.K) # type: ignore[attr-defined]
if len(set(Ks)) != 1:
raise ShapeError(f"Categorical K mismatch for triple product: {Ks}.")
def _cat_tensor(x: LeafModule) -> Tensor:
if isinstance(x, Categorical):
return x.probs.to(dtype=torch.float64)
return cast(_SignedCategorical, x).weights.to(dtype=torch.float64)
p1 = rearrange(_cat_tensor(a), "f ci r k -> f ci 1 1 r k")
p2 = rearrange(_cat_tensor(b), "f cj r k -> f 1 cj 1 r k")
p3 = rearrange(_cat_tensor(c), "f ck r k -> f 1 1 ck r k")
out = torch.sum(p1 * p2 * p3, dim=-1)
elif isinstance(a, Normal) and isinstance(b, Normal) and isinstance(c, Normal):
mu1 = rearrange(a.loc.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
mu2 = rearrange(b.loc.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
mu3 = rearrange(c.loc.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
s1 = rearrange(a.scale.to(dtype=torch.float64).clamp_min(1e-30), "f ci r -> f ci 1 1 r")
s2 = rearrange(b.scale.to(dtype=torch.float64).clamp_min(1e-30), "f cj r -> f 1 cj 1 r")
s3 = rearrange(c.scale.to(dtype=torch.float64).clamp_min(1e-30), "f ck r -> f 1 1 ck r")
tau1 = 1.0 / s1.pow(2)
tau2 = 1.0 / s2.pow(2)
tau3 = 1.0 / s3.pow(2)
tau = tau1 + tau2 + tau3
m = (tau1 * mu1 + tau2 * mu2 + tau3 * mu3) / tau
quad = (tau1 * mu1.pow(2) + tau2 * mu2.pow(2) + tau3 * mu3.pow(2)) - tau * m.pow(2)
log_pref = -torch.log(torch.tensor(2.0 * torch.pi, dtype=torch.float64, device=mu1.device))
log_pref = log_pref - (torch.log(s1) + torch.log(s2) + torch.log(s3)) - 0.5 * torch.log(tau)
out = torch.exp(log_pref - 0.5 * quad)
elif isinstance(a, Bernoulli) and isinstance(b, Bernoulli) and isinstance(c, Bernoulli):
p1 = rearrange(a.probs.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
p2 = rearrange(b.probs.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
p3 = rearrange(c.probs.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
q1 = 1.0 - p1
q2 = 1.0 - p2
q3 = 1.0 - p3
out = (q1 * q2 * q3) + (p1 * p2 * p3)
elif isinstance(a, Categorical) and isinstance(b, Categorical) and isinstance(c, Categorical):
if a.K != b.K or a.K != c.K:
raise ShapeError("Categorical K mismatch for triple product.")
p1 = rearrange(a.probs.to(dtype=torch.float64), "f ci r k -> f ci 1 1 r k")
p2 = rearrange(b.probs.to(dtype=torch.float64), "f cj r k -> f 1 cj 1 r k")
p3 = rearrange(c.probs.to(dtype=torch.float64), "f ck r k -> f 1 1 ck r k")
out = torch.sum(p1 * p2 * p3, dim=-1)
elif isinstance(a, Uniform) and isinstance(b, Uniform) and isinstance(c, Uniform):
a1 = rearrange(a.low.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
b1 = rearrange(a.high.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
a2 = rearrange(b.low.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
b2 = rearrange(b.high.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
a3 = rearrange(c.low.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
b3 = rearrange(c.high.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
len1 = (b1 - a1).clamp_min(1e-30)
len2 = (b2 - a2).clamp_min(1e-30)
len3 = (b3 - a3).clamp_min(1e-30)
left = torch.maximum(torch.maximum(a1, a2), a3)
right = torch.minimum(torch.minimum(b1, b2), b3)
overlap = (right - left).clamp_min(0.0)
out = overlap / (len1 * len2 * len3)
elif isinstance(a, Geometric) and isinstance(b, Geometric) and isinstance(c, Geometric):
p1 = (
rearrange(a.probs.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
.clamp_min(0.0)
.clamp_max(1.0)
)
p2 = (
rearrange(b.probs.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
.clamp_min(0.0)
.clamp_max(1.0)
)
p3 = (
rearrange(c.probs.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
.clamp_min(0.0)
.clamp_max(1.0)
)
qprod = (1.0 - p1) * (1.0 - p2) * (1.0 - p3)
out = (p1 * p2 * p3) / (1.0 - qprod).clamp_min(1e-30)
elif isinstance(a, Binomial) and isinstance(b, Binomial) and isinstance(c, Binomial):
n1 = rearrange(a.total_count.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
n2 = rearrange(b.total_count.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
n3 = rearrange(c.total_count.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
p1 = rearrange(a.probs.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
p2 = rearrange(b.probs.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
p3 = rearrange(c.probs.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
max_n = int(torch.max(torch.maximum(torch.maximum(n1, n2), n3)).item())
ks = rearrange(
torch.arange(0, max_n + 1, dtype=torch.float64, device=p1.device), "k -> k 1 1 1 1 1"
)
n1b = rearrange(n1, "f ci cj ck r -> 1 f ci cj ck r")
n2b = rearrange(n2, "f ci cj ck r -> 1 f ci cj ck r")
n3b = rearrange(n3, "f ci cj ck r -> 1 f ci cj ck r")
lp1 = _binomial_logpmf(ks, n1b, rearrange(p1, "f ci cj ck r -> 1 f ci cj ck r"))
lp2 = _binomial_logpmf(ks, n2b, rearrange(p2, "f ci cj ck r -> 1 f ci cj ck r"))
lp3 = _binomial_logpmf(ks, n3b, rearrange(p3, "f ci cj ck r -> 1 f ci cj ck r"))
mask = (ks <= n1b) & (ks <= n2b) & (ks <= n3b)
lsum = torch.logsumexp(
torch.where(mask, lp1 + lp2 + lp3, torch.full_like(lp1, float("-inf"))), dim=0
)
out = torch.exp(lsum)
elif (
isinstance(a, Hypergeometric) and isinstance(b, Hypergeometric) and isinstance(c, Hypergeometric)
):
K1 = rearrange(a.K.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
N1 = rearrange(a.N.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
n1 = rearrange(a.n.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
K2 = rearrange(b.K.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
N2 = rearrange(b.N.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
n2 = rearrange(b.n.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
K3 = rearrange(c.K.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
N3 = rearrange(c.N.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
n3 = rearrange(c.n.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
if not (torch.allclose(N1, N2) and torch.allclose(N1, N3)):
raise ShapeError("Hypergeometric triple product requires matching N.")
N = N1
max_k = int(
torch.max(
torch.minimum(
torch.minimum(n1, K1), torch.minimum(torch.minimum(n2, K2), torch.minimum(n3, K3))
)
).item()
)
ks = rearrange(
torch.arange(0, max_k + 1, dtype=torch.float64, device=N.device), "k -> k 1 1 1 1 1"
)
K1b = rearrange(K1, "f ci cj ck r -> 1 f ci cj ck r")
K2b = rearrange(K2, "f ci cj ck r -> 1 f ci cj ck r")
K3b = rearrange(K3, "f ci cj ck r -> 1 f ci cj ck r")
Nb = rearrange(N, "f ci cj ck r -> 1 f ci cj ck r")
n1b = rearrange(n1, "f ci cj ck r -> 1 f ci cj ck r")
n2b = rearrange(n2, "f ci cj ck r -> 1 f ci cj ck r")
n3b = rearrange(n3, "f ci cj ck r -> 1 f ci cj ck r")
lp1 = _hypergeo_logpmf(ks, K1b, Nb, n1b)
lp2 = _hypergeo_logpmf(ks, K2b, Nb, n2b)
lp3 = _hypergeo_logpmf(ks, K3b, Nb, n3b)
min1 = rearrange(
torch.maximum(torch.zeros_like(N), n1 + K1 - N), "f ci cj ck r -> 1 f ci cj ck r"
)
max1 = rearrange(torch.minimum(n1, K1), "f ci cj ck r -> 1 f ci cj ck r")
min2 = rearrange(
torch.maximum(torch.zeros_like(N), n2 + K2 - N), "f ci cj ck r -> 1 f ci cj ck r"
)
max2 = rearrange(torch.minimum(n2, K2), "f ci cj ck r -> 1 f ci cj ck r")
min3 = rearrange(
torch.maximum(torch.zeros_like(N), n3 + K3 - N), "f ci cj ck r -> 1 f ci cj ck r"
)
max3 = rearrange(torch.minimum(n3, K3), "f ci cj ck r -> 1 f ci cj ck r")
mask = (ks >= min1) & (ks <= max1) & (ks >= min2) & (ks <= max2) & (ks >= min3) & (ks <= max3)
lsum = torch.logsumexp(
torch.where(mask, lp1 + lp2 + lp3, torch.full_like(lp1, float("-inf"))), dim=0
)
out = torch.exp(lsum)
elif (
isinstance(a, NegativeBinomial)
and isinstance(b, NegativeBinomial)
and isinstance(c, NegativeBinomial)
):
r1 = rearrange(a.total_count.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
r2 = rearrange(b.total_count.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
r3 = rearrange(c.total_count.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
p1 = (
rearrange(a.probs.to(dtype=torch.float64), "f ci r -> f ci 1 1 r")
.clamp_min(1e-30)
.clamp_max(1.0)
)
p2 = (
rearrange(b.probs.to(dtype=torch.float64), "f cj r -> f 1 cj 1 r")
.clamp_min(1e-30)
.clamp_max(1.0)
)
p3 = (
rearrange(c.probs.to(dtype=torch.float64), "f ck r -> f 1 1 ck r")
.clamp_min(1e-30)
.clamp_max(1.0)
)
q = p1 * p2 * p3
def log_term(k: int) -> Tensor:
kk = torch.tensor(float(k), dtype=torch.float64, device=q.device)
lt = (
torch.lgamma(r1 + kk)
- torch.lgamma(r1)
+ torch.lgamma(r2 + kk)
- torch.lgamma(r2)
+ torch.lgamma(r3 + kk)
- torch.lgamma(r3)
- 3.0 * torch.lgamma(kk + 1.0)
+ kk * torch.log(q.clamp_min(1e-30))
)
const = (
r1 * torch.log((1.0 - p1).clamp_min(1e-30))
+ r2 * torch.log((1.0 - p2).clamp_min(1e-30))
+ r3 * torch.log((1.0 - p3).clamp_min(1e-30))
)
return lt + const
logS = _series_logsumexp(log_terms_fn=log_term, max_k=4096, tol=1e-12, device=q.device)
out = torch.exp(logS)
elif isinstance(a, Histogram) and isinstance(b, Histogram) and isinstance(c, Histogram):
edges1 = a.bin_edges.to(dtype=torch.float64, device=a.device)
edges2 = b.bin_edges.to(dtype=torch.float64, device=b.device)
edges3 = c.bin_edges.to(dtype=torch.float64, device=c.device)
u_edges = torch.unique(torch.cat([edges1, edges2, edges3])).to(dtype=torch.float64)
u_edges, _ = torch.sort(u_edges)
seg_left = u_edges[:-1]
seg_right = u_edges[1:]
seg_len = (seg_right - seg_left).clamp_min(0.0)
mids = (seg_left + seg_right) / 2.0
widths1 = (edges1[1:] - edges1[:-1]).to(dtype=torch.float64)
widths2 = (edges2[1:] - edges2[:-1]).to(dtype=torch.float64)
widths3 = (edges3[1:] - edges3[:-1]).to(dtype=torch.float64)
dens1 = a.probs.to(dtype=torch.float64) / rearrange(widths1, "b1 -> 1 1 1 b1")
dens2 = b.probs.to(dtype=torch.float64) / rearrange(widths2, "b2 -> 1 1 1 b2")
dens3 = c.probs.to(dtype=torch.float64) / rearrange(widths3, "b3 -> 1 1 1 b3")
idx1 = (torch.bucketize(mids, edges1, right=True) - 1).clamp(0, widths1.numel() - 1)
idx2 = (torch.bucketize(mids, edges2, right=True) - 1).clamp(0, widths2.numel() - 1)
idx3 = (torch.bucketize(mids, edges3, right=True) - 1).clamp(0, widths3.numel() - 1)
in1 = (mids >= edges1[0]) & (mids < edges1[-1])
in2 = (mids >= edges2[0]) & (mids < edges2[-1])
in3 = (mids >= edges3[0]) & (mids < edges3[-1])
mask = (in1 & in2 & in3).to(dtype=torch.float64)
d1 = dens1.index_select(-1, idx1) # (F,Ca,R,S)
d2 = dens2.index_select(-1, idx2) # (F,Cb,R,S)
d3 = dens3.index_select(-1, idx3) # (F,Cc,R,S)
prod = (
rearrange(d1, "f ci r s -> f ci 1 1 r s")
* rearrange(d2, "f cj r s -> f 1 cj 1 r s")
* rearrange(d3, "f ck r s -> f 1 1 ck r s")
)
out = torch.sum(prod * rearrange(seg_len * mask, "s -> 1 1 1 1 1 s"), dim=-1)
elif (
isinstance(a, PiecewiseLinear)
and isinstance(b, PiecewiseLinear)
and isinstance(c, PiecewiseLinear)
):
if not (a.is_initialized and b.is_initialized and c.is_initialized):
raise UnsupportedOperationError(
"PiecewiseLinear triple product requires all leaves to be initialized."
)
if a.domains is None:
raise UnsupportedOperationError("PiecewiseLinear triple product requires domains.")
for dom in a.domains:
if dom.data_type != DataType.CONTINUOUS:
raise UnsupportedOperationError(
"PiecewiseLinear triple product currently supports continuous domains only."
)
dist_a = a.distribution()
dist_b = b.distribution()
dist_c = c.distribution()
F, Ca, Cb, Cc, R = (
a.out_shape.features,
a.out_shape.channels,
b.out_shape.channels,
c.out_shape.channels,
a.out_shape.repetitions,
)
out = torch.empty((F, Ca, Cb, Cc, R), dtype=torch.float64, device=a.device)
def _get_knots(dist, r: int, leaf_idx: int, f: int) -> tuple[Tensor, Tensor]:
xs = dist.xs[r][leaf_idx][f][0]
ys = dist.ys[r][leaf_idx][f][0]
return xs.to(dtype=torch.float64), ys.to(dtype=torch.float64)
from spflow.modules.leaves.piecewise_linear import interp # local import
u1 = 0.5 - 0.5 / torch.sqrt(torch.tensor(3.0, dtype=torch.float64, device=a.device))
u2 = 0.5 + 0.5 / torch.sqrt(torch.tensor(3.0, dtype=torch.float64, device=a.device))
for r in range(R):
for ca in range(Ca):
for cb in range(Cb):
for cc in range(Cc):
for f in range(F):
xa, ya = _get_knots(dist_a, r, ca, f)
xb, yb = _get_knots(dist_b, r, cb, f)
xc, yc = _get_knots(dist_c, r, cc, f)
grid = torch.unique(torch.cat([xa, xb, xc]))
grid, _ = torch.sort(grid)
if grid.numel() < 2:
out[f, ca, cb, cc, r] = 0.0
continue
fa = interp(grid, xa, ya, extrapolate="constant")
fb = interp(grid, xb, yb, extrapolate="constant")
fc = interp(grid, xc, yc, extrapolate="constant")
h = (grid[1:] - grid[:-1]).clamp_min(0.0)
a0, a1 = fa[:-1], fa[1:]
b0, b1 = fb[:-1], fb[1:]
c0, c1 = fc[:-1], fc[1:]
au1 = a0 + (a1 - a0) * u1
au2 = a0 + (a1 - a0) * u2
bu1 = b0 + (b1 - b0) * u1
bu2 = b0 + (b1 - b0) * u2
cu1 = c0 + (c1 - c0) * u1
cu2 = c0 + (c1 - c0) * u2
integral = torch.sum(h / 2.0 * (au1 * bu1 * cu1 + au2 * bu2 * cu2))
out[f, ca, cb, cc, r] = integral
else:
raise UnsupportedOperationError(
f"Leaf triple product not implemented for {type(a).__name__} × {type(b).__name__} × {type(c).__name__}."
)
if cache is not None:
memo[key] = out
return out
if isinstance(a, Cat) and isinstance(b, Cat) and isinstance(c, Cat):
if a.dim != b.dim or a.dim != c.dim:
raise ShapeError("Cat dim mismatch for triple product.")
if a.dim == 1:
if len(a.inputs) != len(b.inputs) or len(a.inputs) != len(c.inputs):
raise ShapeError("Cat arity mismatch for triple product.")
parts = [
triple_product_tensor(
cast(Module, ai),
cast(Module, bi),
cast(Module, ci),
cache=cache,
signed_sum_types=signed_sum_types,
memo_key=memo_key,
)
for ai, bi, ci in zip(a.inputs, b.inputs, c.inputs)
]
out = torch.cat(parts, dim=0)
if cache is not None:
memo[key] = out
return out
if a.dim == 2:
F = a.out_shape.features
R = a.out_shape.repetitions
Ca = sum(cast(Module, ai).out_shape.channels for ai in a.inputs)
Cb = sum(cast(Module, bi).out_shape.channels for bi in b.inputs)
Cc = sum(cast(Module, ci).out_shape.channels for ci in c.inputs)
out = torch.empty(
(F, Ca, Cb, Cc, R), dtype=torch.float64, device=cast(Module, a.inputs[0]).device
)
a_off = 0
for ai in a.inputs:
ai_mod = cast(Module, ai)
a_ch = ai_mod.out_shape.channels
b_off = 0
for bi in b.inputs:
bi_mod = cast(Module, bi)
b_ch = bi_mod.out_shape.channels
c_off = 0
for ci in c.inputs:
ci_mod = cast(Module, ci)
c_ch = ci_mod.out_shape.channels
out[
:, a_off : a_off + a_ch, b_off : b_off + b_ch, c_off : c_off + c_ch, :
] = triple_product_tensor(
ai_mod,
bi_mod,
ci_mod,
cache=cache,
signed_sum_types=signed_sum_types,
memo_key=memo_key,
)
c_off += c_ch
b_off += b_ch
a_off += a_ch
if cache is not None:
memo[key] = out
return out
raise UnsupportedOperationError(f"triple_product does not support Cat(dim={a.dim}).")
if isinstance(a, Product) and isinstance(b, Product) and isinstance(c, Product):
child_t = triple_product_tensor(
cast(Module, a.inputs),
cast(Module, b.inputs),
cast(Module, c.inputs),
cache=cache,
signed_sum_types=signed_sum_types,
memo_key=memo_key,
)
out = torch.prod(child_t, dim=0, keepdim=True)
if cache is not None:
memo[key] = out
return out
sum_types = (Sum, *signed_sum_types)
if isinstance(a, sum_types) and isinstance(b, sum_types) and isinstance(c, sum_types):
child_t = triple_product_tensor(
cast(Module, a.inputs),
cast(Module, b.inputs),
cast(Module, c.inputs),
cache=cache,
signed_sum_types=signed_sum_types,
memo_key=memo_key,
)
wa = a.weights.to(dtype=torch.float64) # type: ignore[attr-defined]
wb = b.weights.to(dtype=torch.float64) # type: ignore[attr-defined]
wc = c.weights.to(dtype=torch.float64) # type: ignore[attr-defined]
out = torch.einsum("fiar,fjbr,fkcr,fijkr->fabcr", wa, wb, wc, child_t)
if cache is not None:
memo[key] = out
return out
raise UnsupportedOperationError(
f"triple_product_tensor not implemented for {type(a).__name__} × {type(b).__name__} × {type(c).__name__}."
)
def triple_product_scalar(
    a: Module,
    b: Module,
    c: Module,
    *,
    cache: Cache | None = None,
    signed_sum_types: Sequence[type[Module]] = (),
    memo_key: str = "_triple_product_memo",
) -> Tensor:
    """Return the triple product of three scalar-output modules.

    Thin wrapper around ``triple_product_tensor`` for the case where every
    module has ``out_shape == (1, 1, 1)``: the tensor routine is evaluated
    and its single ``[0, 0, 0, 0, 0]`` entry is returned.

    Args:
        a: First module.
        b: Second module.
        c: Third module.
        cache: Optional cache, forwarded unchanged to ``triple_product_tensor``.
        signed_sum_types: Additional module types that ``triple_product_tensor``
            handles as sum nodes alongside ``Sum``; forwarded unchanged.
        memo_key: Forwarded unchanged to ``triple_product_tensor``
            (presumably the name of the memo table inside ``cache`` -- confirm
            against that function's head).

    Returns:
        The element at index ``[0, 0, 0, 0, 0]`` of the triple-product tensor.

    Raises:
        ShapeError: If any of the three modules does not have
            ``out_shape == (1, 1, 1)``.
    """
    scalar_shape = (1, 1, 1)
    # Check the operands in order (a, b, c) and fail on the first mismatch.
    for operand in (a, b, c):
        if tuple(operand.out_shape) != scalar_shape:
            raise ShapeError("triple_product_scalar expects all modules to have out_shape == (1,1,1).")

    product = triple_product_tensor(
        a,
        b,
        c,
        cache=cache,
        signed_sum_types=signed_sum_types,
        memo_key=memo_key,
    )
    # All five axes are indexed at 0, picking out the single entry.
    return product[0, 0, 0, 0, 0]