# Source code for spflow.learn.build_socs

"""SOCS structure builder utilities.

SOCS (Σ2cmp) requires a set of compatible component circuits (same structure).
This helper provides a minimal, SPFlow-native way to construct such components
by cloning a template circuit and optionally converting `Sum` nodes into
`SignedSum` nodes with perturbed (possibly negative) weights.
"""

from __future__ import annotations

import copy

import torch

from spflow.exceptions import InvalidParameterError
from spflow.modules.leaves.categorical import Categorical
from spflow.modules.module import Module
from spflow.modules.sos.socs import SOCS
from spflow.modules.sums.signed_sum import SignedSum
from spflow.modules.sums.sum import Sum
from spflow.utils.compatibility import check_compatible_components


def build_socs(
    template: Module,
    *,
    num_components: int,
    signed: bool = True,
    noise_scale: float = 0.05,
    flip_prob: float = 0.5,
    seed: int | None = None,
) -> SOCS:
    """Build a SOCS model from a compatible component template.

    The template is deep-copied once per component. When ``signed`` is True,
    every `Sum` node found in a clone (the root included) is replaced by a
    `SignedSum` whose weights are the original weights with random sign flips
    and optional additive Gaussian noise. When ``signed`` is False no weights
    are perturbed, so all components are plain copies of the template.

    A `Sum` node that is reachable from several parents inside one clone is
    converted exactly once and the resulting `SignedSum` is re-used at every
    reference, so parameter sharing present in the template is preserved.

    Args:
        template: A SPFlow module representing a (typically scalar-output) circuit.
            This circuit is deep-copied `num_components` times to ensure all
            components share the same structure.
        num_components: Number of components r.
        signed: If True, convert all `Sum` nodes in each clone to `SignedSum` nodes
            with perturbed weights (allowing negative weights).
        noise_scale: Standard deviation of additive Gaussian noise applied to copied
            weights when `signed=True`.
        flip_prob: Probability of flipping the sign of each weight entry when
            `signed=True`. Must be in [0, 1].
        seed: Optional random seed used for weight perturbations. When None, the
            global torch RNG is used.

    Returns:
        A `SOCS` module with `num_components` compatible components.

    Raises:
        InvalidParameterError: If `num_components < 1`, `noise_scale < 0`, or
            `flip_prob` lies outside [0, 1].
    """
    if num_components < 1:
        raise InvalidParameterError("num_components must be >= 1.")
    if noise_scale < 0.0:
        raise InvalidParameterError("noise_scale must be >= 0.")
    if not (0.0 <= flip_prob <= 1.0):
        raise InvalidParameterError("flip_prob must be in [0, 1].")

    # One generator is shared by all components so that every clone receives a
    # different (but reproducible) perturbation.
    gen = None
    if seed is not None:
        gen = torch.Generator(device=template.device)
        gen.manual_seed(int(seed))

    def _convert_sum(node: Sum) -> SignedSum:
        """Return a `SignedSum` over the same inputs with perturbed weights."""
        w = node.weights.detach().clone()
        if gen is None:
            flip = (torch.rand_like(w) < flip_prob).to(dtype=w.dtype)
        else:
            flip = (
                torch.rand(w.shape, dtype=w.dtype, device=w.device, generator=gen) < flip_prob
            ).to(dtype=w.dtype)
        sign = 1.0 - 2.0 * flip  # {+1,-1}
        w = w * sign
        if noise_scale > 0.0:
            if gen is None:
                w = w + noise_scale * torch.randn_like(w)
            else:
                w = w + noise_scale * torch.randn(
                    w.shape, dtype=w.dtype, device=w.device, generator=gen
                )
        return SignedSum(
            inputs=node.inputs,
            out_channels=node.out_shape.channels,
            num_repetitions=node.out_shape.repetitions,
            weights=w,
        )

    def _transform_in_place(
        root: torch.nn.Module,
        converted: dict[int, tuple[torch.nn.Module, torch.nn.Module]],
        visited: set[int],
    ) -> None:
        """Replace `Sum` descendants of `root`, visiting every module once."""
        if id(root) in visited:
            return
        visited.add(id(root))
        # Iterate the raw registry (not `named_children()`) so that a child
        # registered under two names in the same parent is rewired at both.
        for name, child in list(root._modules.items()):
            if child is None:
                continue
            if signed and isinstance(child, Sum):
                entry = converted.get(id(child))
                if entry is None:
                    # Keep the original node alive inside the cache entry so
                    # its id() cannot be recycled while the cache is in use.
                    entry = (child, _convert_sum(child))
                    converted[id(child)] = entry
                child = entry[1]
                root._modules[name] = child
            _transform_in_place(child, converted, visited)

    components: list[Module] = []
    for _ in range(num_components):
        comp = copy.deepcopy(template)
        # Fresh bookkeeping per clone: each deep copy has its own node objects.
        converted: dict[int, tuple[torch.nn.Module, torch.nn.Module]] = {}
        visited: set[int] = set()
        if signed and isinstance(comp, Sum):
            original_root = comp
            comp = _convert_sum(original_root)
            converted[id(original_root)] = (original_root, comp)
        _transform_in_place(comp, converted, visited)
        components.append(comp)

    check_compatible_components(components)
    return SOCS(components)
def build_abs_weight_proposal(component: Module, *, eps: float = 1e-8) -> Module:
    """Build a monotone proposal q(x) from a (possibly signed) component.

    Works on a deep copy of `component`. Each `SignedSum` is replaced with a
    standard `Sum` whose weights are proportional to `abs(weights) + eps`
    (normalized along dim 1), and each node whose class is named
    ``SignedCategorical`` (and exposes `weights` and `K`) is replaced with a
    `Categorical` whose probabilities are proportional to `abs(weights) + eps`
    (normalized along the last dim).

    A node that is reachable from several parents is converted exactly once
    and the replacement is re-used at every reference, so sharing present in
    the component is preserved in the proposal.

    Args:
        component: Component circuit to convert. It is not modified.
        eps: Small additive constant to avoid all-zero abs weights.

    Returns:
        A new `Module` that supports `.sample()` and `.log_likelihood()` and can be
        used as an independence proposal.

    Raises:
        InvalidParameterError: If `eps <= 0`.
    """
    if eps <= 0.0:
        raise InvalidParameterError("eps must be > 0.")

    prop = copy.deepcopy(component)

    def _convert_signed(node: SignedSum) -> Sum:
        """Return a `Sum` over the same inputs with normalized |weights|."""
        w = node.weights.detach()
        w = torch.abs(w) + w.new_tensor(float(eps))
        w = w / w.sum(dim=1, keepdim=True).clamp_min(1e-12)
        return Sum(
            inputs=node.inputs,
            out_channels=node.out_shape.channels,
            num_repetitions=node.out_shape.repetitions,
            weights=w,
        )

    def _is_signed_categorical(node: torch.nn.Module) -> bool:
        """Detect a SignedCategorical leaf by class name and attributes."""
        return (
            node.__class__.__name__ == "SignedCategorical"
            and hasattr(node, "weights")
            and hasattr(node, "K")
        )

    def _convert_signed_categorical(node: torch.nn.Module) -> Categorical:
        """Return a `Categorical` with probabilities proportional to |weights|."""
        probs = torch.abs(node.weights.detach()) + node.weights.new_tensor(float(eps))  # type: ignore[attr-defined]
        probs = probs / probs.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        return Categorical(
            scope=node.scope,  # type: ignore[attr-defined]
            out_channels=node.out_shape.channels,  # type: ignore[attr-defined]
            num_repetitions=node.out_shape.repetitions,  # type: ignore[attr-defined]
            K=node.K,  # type: ignore[attr-defined]
            probs=probs,
        )

    # id(original) -> (original, replacement). The original is stored too so
    # that its id() cannot be recycled while the cache is in use.
    converted: dict[int, tuple[torch.nn.Module, torch.nn.Module]] = {}
    visited: set[int] = set()

    def _convert(node: torch.nn.Module) -> torch.nn.Module:
        """Return the (cached) monotone replacement for `node`, or `node` itself."""
        is_signed_sum = isinstance(node, SignedSum)
        if not is_signed_sum and not _is_signed_categorical(node):
            return node
        entry = converted.get(id(node))
        if entry is None:
            if is_signed_sum:
                replacement = _convert_signed(node)
            else:
                replacement = _convert_signed_categorical(node)
            entry = (node, replacement)
            converted[id(node)] = entry
        return entry[1]

    def _transform(root: torch.nn.Module) -> None:
        """Rewire signed descendants of `root`, visiting every module once."""
        if id(root) in visited:
            return
        visited.add(id(root))
        # Iterate the raw registry (not `named_children()`) so that a child
        # registered under two names in the same parent is rewired at both.
        for name, child in list(root._modules.items()):
            if child is None:
                continue
            replacement = _convert(child)
            if replacement is not child:
                root._modules[name] = replacement
            _transform(replacement)

    prop = _convert(prop)
    _transform(prop)
    return prop
def build_complex_socs(real: Module, imag: Module) -> SOCS:
    """Build a SOCS model equivalent to a complex squared circuit |c|^2.

    Uses the reduction

        c(x) = a(x) + i b(x)   =>   |c(x)|^2 = a(x)^2 + b(x)^2

    and therefore returns a SOCS with the two components `[a, b]`. No
    complex-valued parameters are introduced; the squared-magnitude semantics
    come from summing the two squared real circuits.

    Args:
        real: Circuit computing a(x).
        imag: Circuit computing b(x).

    Returns:
        A `SOCS` module with two components, `real` first and `imag` second.

    Raises:
        InvalidParameterError: If the two circuits differ in `scope` or in
            `out_shape`.
    """
    # Guard 1: both parts must be defined over the same variables.
    scopes_match = real.scope == imag.scope
    if not scopes_match:
        raise InvalidParameterError(
            "build_complex_socs requires real and imag circuits to have identical scope."
        )

    # Guard 2: both parts must produce outputs of the same shape.
    real_shape = tuple(real.out_shape)
    imag_shape = tuple(imag.out_shape)
    if real_shape != imag_shape:
        raise InvalidParameterError(
            "build_complex_socs requires real and imag circuits to have identical out_shape; "
            f"got {real_shape} vs {imag_shape}."
        )

    parts = [real, imag]
    return SOCS(parts)