Source code for spflow.modules.leaves.negative_binomial

import torch
from torch import Tensor, nn

from spflow.exceptions import InvalidParameterCombinationError
from spflow.meta.data import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import init_parameter, _handle_mle_edge_cases
from spflow.utils.projections import proj_bounded_to_real, proj_real_to_bounded
from spflow.utils.sampling_context import SIMPLE


class NegativeBinomial(LeafModule):
    """Negative Binomial distribution leaf matching ``torch.distributions.NegativeBinomial``.

    In PyTorch, ``NegativeBinomial(total_count=r, probs=p)`` models the number of
    **successes** observed before ``r`` failures occur, where each trial succeeds
    with probability ``p``. This leaf uses that exact parameterization.

    Notes:
        - ``total_count`` (``r``) is fixed and cannot be learned.
        - ``probs`` (``p``) is learnable and stored in logit-space for numerical stability.

    Attributes:
        total_count: Fixed number of failures before stopping (buffer).
        probs: Success probability in ``[0, 1]`` (stored in logit-space).
        distribution: Underlying ``torch.distributions.NegativeBinomial``.
    """

    def __init__(
        self,
        scope: Scope,
        out_channels: int = 1,
        num_repetitions: int = 1,
        total_count: Tensor | None = None,
        probs: Tensor | None = None,
        logits: Tensor | None = None,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
    ):
        """Initialize Negative Binomial distribution leaf module.

        Args:
            scope: Scope object specifying the scope of the distribution.
            out_channels: Number of output channels (inferred from params if None).
            num_repetitions: Number of repetitions for the distribution.
            total_count: Number of failures before stopping (required). Tensors and
                plain Python numbers are accepted; the value is broadcast to
                ``event_shape``.
            probs: Success probability tensor (optional; mutually exclusive with ``logits``).
                If neither ``probs`` nor ``logits`` is given, ``init_parameter`` initializes
                probs using ``torch.rand``.
            logits: Logits of the success probability (optional; mutually exclusive
                with ``probs``).
            parameter_fn: Optional neural network for parameter generation.
            validate_args: Whether to enable torch.distributions argument validation.

        Raises:
            InvalidParameterCombinationError: If ``total_count`` is missing, or if both
                ``probs`` and ``logits`` are given.
        """
        if total_count is None:
            raise InvalidParameterCombinationError(
                "'total_count' parameter is required for NegativeBinomial distribution"
            )
        if probs is not None and logits is not None:
            raise InvalidParameterCombinationError(
                "NegativeBinomial accepts either probs or logits, not both."
            )

        param_source = logits if logits is not None else probs
        super().__init__(
            scope=scope,
            out_channels=out_channels,
            num_repetitions=num_repetitions,
            params=[param_source],
            parameter_fn=parameter_fn,
            validate_args=validate_args,
        )

        # Logits live on the whole real line (randn); probs live in [0, 1) (rand).
        init_fn = torch.randn if logits is not None else torch.rand
        init_value = init_parameter(param=param_source, event_shape=self.event_shape, init=init_fn)

        # total_count is fixed (never learned), so it is a buffer rather than a Parameter.
        # torch.as_tensor lets callers pass plain Python numbers as well as tensors.
        total_count = torch.broadcast_to(torch.as_tensor(total_count), self.event_shape).clone()
        self.register_buffer("_total_count", total_count)

        # The learnable parameter is always stored in logit-space.
        logits_tensor = (
            init_value
            if logits is not None
            else proj_bounded_to_real(init_value, lb=0.0, ub=1.0)
        )
        self._logits = nn.Parameter(logits_tensor)

    @property
    def total_count(self) -> Tensor:
        """Returns the fixed number of failures before stopping."""
        return self._total_count

    @total_count.setter
    def total_count(self, total_count: Tensor):
        """Sets the number of failures before stopping.

        Args:
            total_count: Non-negative number of failures (tensor or Python number).
                It is moved to the buffer's device and broadcast to ``event_shape``,
                mirroring ``__init__``.
        """
        # Assigning a non-Tensor to a registered buffer raises TypeError in nn.Module,
        # so convert first.
        value = torch.as_tensor(total_count, device=self._total_count.device)
        self._total_count = torch.broadcast_to(value, self.event_shape).clone()

    @property
    def probs(self) -> Tensor:
        """Success probability in natural space (inverse projection of the stored logits)."""
        return proj_real_to_bounded(self._logits, lb=0.0, ub=1.0)

    @probs.setter
    def probs(self, value: Tensor) -> None:
        """Set success probability (stored as logits; no validation after init)."""
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = proj_bounded_to_real(value_tensor, lb=0.0, ub=1.0)

    @property
    def logits(self) -> Tensor:
        """Logits for success probability."""
        return self._logits

    @logits.setter
    def logits(self, value: Tensor) -> None:
        """Overwrite the stored logits in place (keeps the same nn.Parameter object)."""
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = value_tensor

    @property
    def _supported_value(self):
        """Fallback value for unsupported data."""
        return 0

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.NegativeBinomial]:
        return torch.distributions.NegativeBinomial

    @property
    def _torch_distribution_class_with_differentiable_sampling(
        self,
    ) -> type[torch.distributions.Distribution]:
        return NegativeBinomialWithDifferentiableSamplingSIMPLE

    def params(self) -> dict[str, Tensor]:
        """Returns distribution parameters."""
        return {"total_count": self.total_count, "logits": self.logits}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw MLE estimates for negative binomial distribution (without broadcasting).

        With ``r = total_count`` held fixed, the weighted MLE of the success probability is
        ``p = sum(w * x) / (sum(w * x) + r * sum(w))``, where ``x`` are success counts
        (the quantity modeled by ``torch.distributions.NegativeBinomial``).

        Args:
            data: Scope-filtered data (success counts observed before ``total_count`` failures).
            weights: Normalized sample weights.
            bias_correction: Whether to apply bias correction (subtracts 1 from the
                weighted failure total).

        Returns:
            Dictionary with 'probs' estimate (shape: out_features).
        """
        n_failures = weights.sum(dim=0) * self.total_count
        if bias_correction:
            n_failures = n_failures - 1
        n_successes = (weights * data).sum(0)
        p_est = 1 - n_failures / (n_successes + n_failures)

        # Handle edge cases (NaN, zero, or near-zero/one probs) before broadcasting
        p_est = _handle_mle_edge_cases(p_est, lb=0.0, ub=1.0)
        return {"probs": p_est}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for NegativeBinomial distribution.

        Explicitly handles the parameter type:
        - probs: Property with setter, calls property setter which updates _logits

        Note: total_count is fixed and not updated during MLE.

        Args:
            params_dict: Dictionary with 'probs' parameter value.
        """
        self.probs = params_dict["probs"]  # Uses property setter
class NegativeBinomialWithDifferentiableSamplingSIMPLE(torch.distributions.NegativeBinomial):
    """NegativeBinomial distribution with differentiable rsample via truncated SIMPLE.

    Notes:
        The NegativeBinomial distribution has infinite support over {0, 1, 2, ...}.
        This implementation uses a truncated support [0..Kmax], where Kmax is
        ``ceil(max over the batch of (mean + _NUM_STDS * std + _SUPPORT_PADDING))``,
        capped at ``_MAX_SUPPORT`` to keep computation bounded. Probability mass
        beyond Kmax is discarded.
    """

    has_rsample = True
    # Hard cap on Kmax, the largest support value considered when sampling.
    _MAX_SUPPORT: int = 2048
    # Number of standard deviations above the mean covered by the truncated support.
    _NUM_STDS: float = 10.0
    # Constant slack added to the bound so low-variance cases are not truncated too tightly.
    _SUPPORT_PADDING: float = 10.0

    def sample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        """Draw samples by delegating to :meth:`rsample` (the result stays differentiable)."""
        return self.rsample(sample_shape)

    def _max_support_index(self) -> int:
        """Return Kmax, the largest support value used by :meth:`rsample`.

        The bound is ceiled and clamped in floating point *before* conversion to an
        integer. For probs close to 1 the mean/std can exceed the int64 range, or become
        inf/NaN, and casting such values to int64 is undefined. That could collapse the
        support to {0}; here any such bound saturates at ``_MAX_SUPPORT`` instead.
        An empty batch yields 0.
        """
        probs = self.probs
        dtype = probs.dtype
        total_count = self.total_count.to(device=probs.device, dtype=dtype)

        # Keep 1 - p away from zero so the moments stay finite as p -> 1.
        denom = torch.clamp(1.0 - probs, min=torch.finfo(dtype).eps)
        mean = total_count * probs / denom
        var = total_count * probs / (denom * denom)
        std = torch.sqrt(torch.clamp(var, min=0.0))

        bound = mean + self._NUM_STDS * std + self._SUPPORT_PADDING
        if bound.numel() == 0:
            return 0

        cap = float(self._MAX_SUPPORT)
        bound = torch.ceil(bound.max())
        bound = torch.nan_to_num(bound, nan=cap, posinf=cap, neginf=0.0)
        return int(torch.clamp(bound, min=0.0, max=cap).item())

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        """Draw differentiable samples via SIMPLE over the truncated support [0..Kmax].

        Args:
            sample_shape: Leading shape of the samples to draw.

        Returns:
            Sample values formed as the sum over the support of SIMPLE's output times
            the support values. Assumes SIMPLE returns a (relaxed) one-hot tensor shaped
            like its ``logits`` input, giving ``(*sample_shape, *batch_shape)``.
            TODO confirm against SIMPLE.
        """
        sample_shape = torch.Size(sample_shape)
        probs = self.probs
        max_k = self._max_support_index()
        num_values = max_k + 1

        # Candidate support values 0..Kmax, laid out along a new leading axis.
        k = torch.arange(num_values, device=probs.device, dtype=probs.dtype)  # (K,)
        value = k.reshape(num_values, *([1] * len(self.batch_shape))).expand(
            num_values, *self.batch_shape
        )

        # Log-probability of every candidate value; move the support axis last.
        base_dist = torch.distributions.NegativeBinomial(
            total_count=self.total_count, logits=self.logits, validate_args=False
        )
        logits = base_dist.log_prob(value).movedim(0, -1)  # (*batch_shape, K)
        if sample_shape:
            logits = logits.expand(*sample_shape, *logits.shape)

        samples_oh = SIMPLE(logits=logits, dim=-1, is_mpe=False)
        return (samples_oh * k).sum(dim=-1)