"""Binomial distribution leaf module (``spflow.modules.leaves.binomial``)."""

from enum import Enum

import torch
from torch import Tensor, nn

from spflow.exceptions import InvalidParameterCombinationError
from spflow.meta.data import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import init_parameter, _handle_mle_edge_cases
from spflow.utils.projections import proj_bounded_to_real, proj_real_to_bounded
from spflow.utils.sampling_context import SIMPLE


class BinomialDifferentiableSamplingMethod(str, Enum):
    """Selects how a Binomial leaf produces differentiable (reparameterized) samples."""

    # Normal (Gaussian) approximation with straight-through rounding.
    NORMAL = "normal"
    # SIMPLE gradient estimator over the discrete support {0..n}.
    SIMPLE = "simple"


class Binomial(LeafModule):
    """Binomial distribution leaf module for probabilistic circuits.

    Implements univariate Binomial distributions as leaf nodes in
    probabilistic circuits. Supports parameter learning through maximum
    likelihood estimation and efficient inference through PyTorch's
    built-in distributions.

    The Binomial distribution models the number of successes in a fixed
    number of independent Bernoulli trials, with probability mass
    function::

        P(X = k | n, p) = C(n, k) * p^k * (1-p)^(n-k)

    where n is the number of trials (fixed), p is the success probability
    (learnable, stored in logit-space for numerical stability), and k is
    the number of successes (0 <= k <= n).

    Attributes:
        probs: Success probability in [0, 1]; recomputed on each access by
            projecting the stored logit-space parameter.
        logits: Raw logit-space success-probability parameter (learnable).
        total_count: Number of trials; fixed, non-learnable buffer.
        differentiable_sampling_method: Strategy used for differentiable
            (reparameterized) sampling.
    """

    def __init__(
        self,
        scope: Scope,
        out_channels: int = 1,
        num_repetitions: int = 1,
        total_count: Tensor | None = None,
        probs: Tensor | None = None,
        logits: Tensor | None = None,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
        differentiable_sampling_method: BinomialDifferentiableSamplingMethod = BinomialDifferentiableSamplingMethod.NORMAL,
    ) -> None:
        """Initialize Binomial distribution leaf module.

        Args:
            scope: Scope object specifying the scope of the distribution.
            out_channels: Number of output channels.
            num_repetitions: Number of repetitions for the distribution.
            total_count: Number of trials tensor (required); broadcast to
                the event shape and stored as a fixed buffer.
            probs: Success probability tensor (optional, randomly
                initialized if None). Mutually exclusive with ``logits``.
            logits: Log-odds tensor for the success probability. Mutually
                exclusive with ``probs``.
            parameter_fn: Optional neural network for parameter generation.
            validate_args: Whether to enable torch.distributions argument
                validation.
            differentiable_sampling_method: Method for differentiable
                sampling (default: NORMAL).

        Raises:
            InvalidParameterCombinationError: If ``total_count`` is missing,
                or if both ``probs`` and ``logits`` are provided.
        """
        if total_count is None:
            raise InvalidParameterCombinationError("'n' parameter is required for Binomial distribution")
        if probs is not None and logits is not None:
            raise InvalidParameterCombinationError("Binomial accepts either probs or logits, not both.")
        # At most one of probs/logits is set; whichever is present seeds the
        # learnable parameter (may be None -> random initialization below).
        param_source = logits if logits is not None else probs
        super().__init__(
            scope=scope,
            out_channels=out_channels,
            num_repetitions=num_repetitions,
            params=[param_source],
            parameter_fn=parameter_fn,
            validate_args=validate_args,
        )
        # Logits live on the whole real line (randn); probs in [0, 1] (rand).
        init_fn = torch.randn if logits is not None else torch.rand
        init_value = init_parameter(param=param_source, event_shape=self.event_shape, init=init_fn)
        # Register total_count as a fixed (non-learnable) buffer; clone so the
        # broadcast view does not alias the caller's tensor.
        total_count = torch.broadcast_to(total_count, self.event_shape).clone()
        self.register_buffer("_total_count", total_count)
        # Store the parameter in logit space; prob-valued inputs are projected
        # from (0, 1) onto the real line first.
        logits_tensor = init_value if logits is not None else proj_bounded_to_real(init_value, lb=0.0, ub=1.0)
        self._logits = nn.Parameter(logits_tensor)
        self.differentiable_sampling_method = differentiable_sampling_method

    @property
    def total_count(self) -> Tensor:
        """Returns the number of trials."""
        return self._total_count

    @total_count.setter
    def total_count(self, total_count: Tensor) -> None:
        """Sets the number of trials.

        Args:
            total_count: Tensor representing the number of trials.
        """
        # `_total_count` is a registered buffer, so this assignment replaces
        # the buffer entry via nn.Module.__setattr__.
        self._total_count = total_count

    @property
    def probs(self) -> Tensor:
        """Success probability in natural space (inverse projection of the stored logits)."""
        return proj_real_to_bounded(self._logits, lb=0.0, ub=1.0)

    @probs.setter
    def probs(self, value: Tensor) -> None:
        """Set success probability (stored in logit space, no validation after init)."""
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        # Write through .data so the nn.Parameter object itself (and any
        # optimizer registration of it) is preserved.
        self._logits.data = proj_bounded_to_real(value_tensor, lb=0.0, ub=1.0)

    @property
    def logits(self) -> Tensor:
        """Logits for the success probability."""
        return self._logits

    @logits.setter
    def logits(self, value: Tensor) -> None:
        # Same .data write-through pattern as the probs setter.
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = value_tensor

    @property
    def _supported_value(self) -> float:
        """Fallback value for unsupported data."""
        return 0.0

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.Binomial]:
        # Backing torch distribution used for (non-differentiable) inference.
        return torch.distributions.Binomial

    def params(self) -> dict[str, Tensor]:
        """Returns distribution parameters."""
        return {"total_count": self.total_count, "logits": self.logits}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw MLE estimates for binomial distribution (without broadcasting).

        Args:
            data: Input data tensor (success counts per instance).
            weights: Weight tensor for each data point.
            bias_correction: Not used for Binomial (included for interface consistency).

        Returns:
            Dictionary with 'probs' estimate (shape: out_features).
        """
        normalized_weights = weights / weights.sum(dim=0)
        # normalized_weights sums to 1 along dim 0, so n_total reduces to
        # total_count; written this way for symmetry with n_success.
        n_total = normalized_weights.sum(dim=0) * self.total_count
        n_success = (normalized_weights * data).sum(0)
        probs_est = n_success / n_total
        # Handle edge cases (NaN, out of bounds) before broadcasting
        probs_est = _handle_mle_edge_cases(probs_est, lb=0.0, ub=1.0)
        return {"probs": probs_est}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for Binomial distribution.

        Explicitly handles the parameter type:
        - probs: Property with setter; assigning it updates ``_logits``
          through the projection in the property setter.

        Args:
            params_dict: Dictionary with 'probs' parameter value.
        """
        self.probs = params_dict["probs"]  # Uses property setter

    @property
    def _torch_distribution_class_with_differentiable_sampling(
        self,
    ) -> type[torch.distributions.Distribution]:
        # Dispatch on the configured method; the returned classes are defined
        # later in this module.
        if self.differentiable_sampling_method == BinomialDifferentiableSamplingMethod.NORMAL:
            return BinomialWithDifferentiableSamplingNormal
        elif self.differentiable_sampling_method == BinomialDifferentiableSamplingMethod.SIMPLE:
            return BinomialWithDifferentiableSamplingSIMPLE
        else:
            raise ValueError(
                f"Unsupported differentiable sampling method: {self.differentiable_sampling_method}"
            )
class BinomialWithDifferentiableSamplingNormal(torch.distributions.Binomial):
    """Binomial whose ``rsample`` is differentiable via a Normal surrogate.

    Draws ``mu + sigma * eps`` with ``eps ~ N(0, 1)`` from the moment-matched
    Normal, clips the draw into ``[0, n]``, rounds it, and applies a
    straight-through estimator so the forward value is the rounded count
    while gradients flow through the continuous draw.
    """

    has_rsample = True

    def sample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        # Plain sampling also routes through the differentiable path.
        return self.rsample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        """Generate differentiable samples: Normal approximation + STE rounding."""
        shape = self._extended_shape(torch.Size(sample_shape))
        p = self.probs
        n = self.total_count.to(dtype=p.dtype, device=p.device)
        # Moment-matched Normal surrogate for Binomial(n, p).
        mu = n * p
        sigma = torch.sqrt(torch.clamp(n * p * (1.0 - p), min=0.0))
        noise = torch.randn(shape, dtype=p.dtype, device=p.device)
        # Clip the continuous draw into the valid count range [0, n].
        soft = torch.minimum(torch.clamp(mu + sigma * noise, min=0.0), n)
        hard = torch.round(soft)
        # STE: forward the rounded value, backprop through the soft one.
        return (hard - soft).detach() + soft


class BinomialWithDifferentiableSamplingSIMPLE(torch.distributions.Binomial):
    """Binomial whose ``rsample`` is differentiable via SIMPLE over {0..n}.

    Builds per-count log-probabilities for every possible outcome ``k``,
    masks counts above each component's own ``n`` with ``-inf``, relaxes the
    resulting categorical with SIMPLE along the class dimension, and returns
    the (relaxed) count as the index-weighted sum of the one-hot sample.
    """

    has_rsample = True

    def sample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        # Plain sampling also routes through the differentiable path.
        return self.rsample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        """Generate differentiable samples using SIMPLE over the count support."""
        shape = torch.Size(sample_shape)
        p = self.probs
        counts = self.total_count.to(device=p.device)
        max_total = int(torch.max(counts).to(dtype=torch.int64).item())
        if max_total < 0:
            raise ValueError(f"total_count must be >= 0, got max(total_count)={max_total}.")
        # Candidate counts 0..K shared by every batch element.  (K,)
        ks = torch.arange(max_total + 1, device=p.device, dtype=p.dtype)
        grid = ks.reshape(max_total + 1, *([1] * len(self.batch_shape))).expand(
            max_total + 1, *self.batch_shape
        )
        # Log-probs come from an unvalidated twin distribution so counts > n
        # do not trip torch.distributions argument validation.
        twin = torch.distributions.Binomial(
            total_count=self.total_count, logits=self.logits, validate_args=False
        )
        class_logits = twin.log_prob(grid).movedim(0, -1)
        # Mask counts beyond each component's own n with -inf.
        over = ks > counts.to(dtype=p.dtype).unsqueeze(-1)
        class_logits = torch.where(over, torch.full_like(class_logits, float("-inf")), class_logits)
        if shape:
            class_logits = class_logits.expand(*shape, *class_logits.shape)
        one_hot = SIMPLE(logits=class_logits, dim=-1, is_mpe=False)
        return (one_hot * ks).sum(dim=-1)