"""spflow.modules.leaves.poisson — Poisson distribution leaf."""

import torch
from torch import Tensor, nn

from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import init_parameter, _handle_mle_edge_cases
from spflow.utils.sampling_context import SIMPLE


class Poisson(LeafModule):
    """Poisson distribution leaf for modeling event counts.

    Parameterized by rate λ > 0 (stored in log-space for numerical stability).

    Attributes:
        rate: Rate parameter λ (stored as log_rate internally).
        distribution: Underlying torch.distributions.Poisson.
    """

    def __init__(
        self,
        scope,
        out_channels: int = 1,
        num_repetitions: int = 1,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
        rate: Tensor | None = None,
    ):
        """Initialize Poisson leaf.

        Args:
            scope: Variable scope (Scope, int, or list[int]).
            out_channels: Number of output channels (default: 1).
            num_repetitions: Number of repetitions (for 3D event shapes).
            parameter_fn: Optional neural network for parameter generation.
            validate_args: Whether to enable torch.distributions argument validation.
            rate: Rate parameter λ > 0. Defaults to ones if not given.
        """
        super().__init__(
            scope=scope,
            out_channels=out_channels,
            num_repetitions=num_repetitions,
            params=[rate],
            parameter_fn=parameter_fn,
            validate_args=validate_args,
        )
        # Default-initialize the rate to ones, then store it in log-space so
        # the natural-space rate read via the `rate` property is always > 0.
        rate = init_parameter(param=rate, event_shape=self._event_shape, init=torch.ones)
        self.log_rate = nn.Parameter(torch.log(rate))

    @property
    def rate(self) -> Tensor:
        """Rate parameter in natural space (read via exp of log_rate)."""
        return torch.exp(self.log_rate)

    @rate.setter
    def rate(self, value: Tensor) -> None:
        """Set rate parameter (stores as log_rate, no validation after init)."""
        self.log_rate.data = torch.log(
            torch.as_tensor(value, dtype=self.log_rate.dtype, device=self.log_rate.device)
        )

    @property
    def _supported_value(self):
        """Fallback value for unsupported data."""
        return 0

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.Poisson]:
        """The torch distribution backing this leaf."""
        return torch.distributions.Poisson

    @property
    def _torch_distribution_class_with_differentiable_sampling(
        self,
    ) -> type[torch.distributions.Distribution]:
        """Distribution variant whose rsample is differentiable (via SIMPLE)."""
        return PoissonWithDifferentiableSamplingSIMPLE

    def params(self) -> dict[str, Tensor]:
        """Returns distribution parameters."""
        return {"rate": self.rate}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw MLE estimates for Poisson distribution (without broadcasting).

        For Poisson distribution, the MLE is simply the weighted mean of the data.

        Args:
            data: Input data tensor.
            weights: Weight tensor for each data point.
            bias_correction: Not used for Poisson.

        Returns:
            Dictionary with 'rate' estimate (shape: out_features).
        """
        n_total = weights.sum(dim=0)
        rate_est = (weights * data).sum(dim=0) / n_total
        # Handle edge cases (NaN, zero, or near-zero rate) before broadcasting
        rate_est = _handle_mle_edge_cases(rate_est, lb=0.0)
        return {"rate": rate_est}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for Poisson distribution.

        Explicitly handles the parameter type:
        - rate: Property with setter, calls property setter which updates log_rate

        Args:
            params_dict: Dictionary with 'rate' parameter value.
        """
        self.rate = params_dict["rate"]  # Uses property setter
class PoissonWithDifferentiableSamplingSIMPLE(torch.distributions.Poisson):
    """Poisson distribution with differentiable rsample via truncated SIMPLE.

    Notes:
        The Poisson distribution has infinite support over {0, 1, 2, ...}.
        This implementation uses a truncated support [0..Kmax] where Kmax is
        inferred from the current rate and capped to keep computation bounded.
    """

    has_rsample = True
    _MAX_SUPPORT: int = 2048

    def sample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        # Route ordinary sampling through the reparameterized path as well.
        return self.rsample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        shape = torch.Size(sample_shape)
        lam = self.rate

        # Truncation point: mean + 10 std + 10 across the whole batch,
        # capped at _MAX_SUPPORT to keep the support grid bounded.
        sigma = torch.sqrt(torch.clamp(lam, min=0.0))
        upper = torch.ceil((lam + 10.0 * sigma + 10.0).max()).to(dtype=torch.int64)
        k_max = int(torch.clamp(upper, min=0, max=self._MAX_SUPPORT).item())

        # Support grid 0..k_max, broadcast over batch dims: (K, *batch_shape).
        support = torch.arange(k_max + 1, device=lam.device, dtype=lam.dtype)  # (K,)
        grid = support.reshape(k_max + 1, *([1] * len(self.batch_shape))).expand(
            k_max + 1, *self.batch_shape
        )

        # Per-category log-probabilities with the support axis moved last.
        helper = torch.distributions.Poisson(rate=lam, validate_args=False)
        log_probs = helper.log_prob(grid).movedim(0, -1)
        if shape:
            log_probs = log_probs.expand(*shape, *log_probs.shape)

        # Differentiable one-hot selection over the truncated support,
        # collapsed back to a count by the weighted sum.
        one_hot = SIMPLE(logits=log_probs, dim=-1, is_mpe=False)
        return (one_hot * support).sum(dim=-1)