Source code for spflow.modules.leaves.gamma

import torch
from torch import Tensor, nn

from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import validate_all_or_none, init_parameter, _handle_mle_edge_cases


class Gamma(LeafModule):
    """Gamma distribution leaf for modeling positive-valued continuous data.

    Parameterized by shape (concentration) α > 0 and rate β > 0. Both are
    stored in log-space as ``log_concentration`` / ``log_rate`` and mapped back
    with ``exp``, so the natural-space values are always strictly positive.

    Attributes:
        concentration: Shape parameter α (property; stored as ``log_concentration``).
        rate: Rate parameter β (property; stored as ``log_rate``).
        log_concentration: ``nn.Parameter`` holding ``log(α)``.
        log_rate: ``nn.Parameter`` holding ``log(β)``.
        distribution: Underlying ``torch.distributions.Gamma``.
    """

    def __init__(
        self,
        scope,
        out_channels: int | None = None,
        num_repetitions: int = 1,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
        concentration: Tensor | None = None,
        rate: Tensor | None = None,
    ):
        """Initialize Gamma distribution leaf.

        Args:
            scope: Variable scope (Scope, int, or list[int]).
            out_channels: Number of output channels (inferred from params if None).
            num_repetitions: Number of repetitions (for 3D event shapes).
            parameter_fn: Optional neural network for parameter generation.
            validate_args: Whether to enable torch.distributions argument validation.
            concentration: Shape parameter α > 0. Must be given together with
                ``rate`` or not at all.
            rate: Rate parameter β > 0. Must be given together with
                ``concentration`` or not at all.

        Note:
            Provided values are not checked for positivity here; ``torch.log``
            of a non-positive value yields ``-inf``/``NaN`` in the stored
            log-parameter.
        """
        super().__init__(
            scope=scope,
            out_channels=out_channels,
            num_repetitions=num_repetitions,
            params=[concentration, rate],
            parameter_fn=parameter_fn,
            validate_args=validate_args,
        )
        # Either both parameters are supplied or both are randomly initialized.
        validate_all_or_none(concentration=concentration, rate=rate)

        # Initialize parameters in the well-behaved range [0.5, 5.0) to avoid
        # extreme values that cause MLE instability.
        def init_gamma_param(shape):
            return torch.rand(shape) * 4.5 + 0.5

        concentration = init_parameter(
            param=concentration, event_shape=self._event_shape, init=init_gamma_param
        )
        rate = init_parameter(param=rate, event_shape=self._event_shape, init=init_gamma_param)

        # Store in log-space; the properties below exponentiate on read.
        self.log_concentration = nn.Parameter(torch.log(concentration))
        self.log_rate = nn.Parameter(torch.log(rate))

    @property
    def concentration(self) -> Tensor:
        """Shape parameter α in natural space, computed as ``exp(log_concentration)``."""
        return torch.exp(self.log_concentration)

    @concentration.setter
    def concentration(self, value: Tensor) -> None:
        """Set the shape parameter by writing ``log(value)`` into ``log_concentration``.

        No validation is performed: non-positive values produce ``-inf``/``NaN``.
        The value is cast to the existing parameter's dtype and device.
        """
        self.log_concentration.data = torch.log(
            torch.as_tensor(value, dtype=self.log_concentration.dtype, device=self.log_concentration.device)
        )

    @property
    def rate(self) -> Tensor:
        """Rate parameter β in natural space, computed as ``exp(log_rate)``."""
        return torch.exp(self.log_rate)

    @rate.setter
    def rate(self, value: Tensor) -> None:
        """Set the rate parameter by writing ``log(value)`` into ``log_rate``.

        No validation is performed: non-positive values produce ``-inf``/``NaN``.
        The value is cast to the existing parameter's dtype and device.
        """
        self.log_rate.data = torch.log(
            torch.as_tensor(value, dtype=self.log_rate.dtype, device=self.log_rate.device)
        )

    @property
    def _supported_value(self):
        """Fallback value for unsupported data."""
        return 1.0

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.Gamma]:
        """The torch distribution class this leaf wraps."""
        return torch.distributions.Gamma

    def conditional_distribution(self, evidence: Tensor) -> torch.distributions.Gamma:
        """Build a Gamma distribution whose parameters are produced from ``evidence``.

        Args:
            evidence: Input passed to ``parameter_fn``.

        Returns:
            ``torch.distributions.Gamma`` with concentration and rate equal to
            ``exp`` of the corresponding ``parameter_fn`` outputs.

        Note:
            Expects ``parameter_fn(evidence)`` to return a mapping with
            ``"concentration"`` and ``"rate"`` entries in unconstrained (log)
            space; ``exp`` maps them to positive values.
        """
        # Pass evidence to parameter network to get (log-space) parameters.
        params = self.parameter_fn(evidence)

        # Exponentiate to ensure positive parameters and construct torch Gamma distribution.
        return torch.distributions.Gamma(
            concentration=torch.exp(params["concentration"]),
            rate=torch.exp(params["rate"]),
            validate_args=self._validate_args,
        )

    def params(self) -> dict[str, Tensor]:
        """Return the natural-space distribution parameters ``concentration`` and ``rate``."""
        return {"concentration": self.concentration, "rate": self.rate}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw weighted MLE-style estimates for the Gamma distribution (no broadcasting).

        Uses the closed-form estimator derived from the Gamma likelihood
        equations (not method-of-moments). With weighted means over dim 0::

            θ̂ = E[x ln x] − E[x] · E[ln x]     (scale)
            α̂ = E[x] / θ̂                       (concentration)
            β̂ = 1 / θ̂                          (rate)

        With ``bias_correction`` the standard first-order corrections are
        applied, using the total weight ``n`` as the effective sample size::

            α̃ = α̂ − (1/n)·(3α̂ − (2/3)·α̂/(1+α̂) − (4/5)·α̂/(1+α̂)²)
            β̃ = β̂ · (n − 1) / n                 (i.e. θ̃ = θ̂ · n / (n − 1))

        Args:
            data: Input data tensor; values are assumed positive (``log`` is taken).
            weights: Weight tensor for each data point, broadcastable against ``data``.
            bias_correction: Whether to apply bias correction to parameter estimates.

        Returns:
            Dictionary with ``'concentration'`` and ``'rate'`` estimates, reduced
            over dim 0 and passed through ``_handle_mle_edge_cases`` with lower
            bound 0.
        """
        n_total = weights.sum(dim=0)
        data_log = data.log()

        # Weighted sample moments needed by the closed-form estimator.
        mean_xlnx = (weights * data_log * data).sum(dim=0) / n_total
        mean_x = (weights * data).sum(dim=0) / n_total
        mean_ln_x = (weights * data_log).sum(dim=0) / n_total

        # Scale estimate θ̂; concentration and rate follow from it.
        theta_est = mean_xlnx - mean_x * mean_ln_x
        concentration_est = mean_x / theta_est
        rate_est = 1 / theta_est

        if bias_correction:
            concentration_est = concentration_est - 1 / n_total * (
                3 * concentration_est
                - 2 / 3 * (concentration_est / (1 + concentration_est))
                - 4 / 5 * (concentration_est / (1 + concentration_est) ** 2)
            )
            # θ̃ = θ̂ · n/(n−1)  ⇔  β̃ = β̂ · (n−1)/n
            rate_est = rate_est * ((n_total - 1) / n_total)

        # Handle edge cases (e.g. degenerate/invalid estimates) before broadcasting.
        concentration_est = _handle_mle_edge_cases(concentration_est, lb=0.0)
        rate_est = _handle_mle_edge_cases(rate_est, lb=0.0)

        return {"concentration": concentration_est, "rate": rate_est}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for the Gamma distribution.

        Both assignments go through the property setters, which write the
        logarithm of the value into ``log_concentration`` / ``log_rate``.

        Args:
            params_dict: Dictionary with 'concentration' and 'rate' parameter values.
        """
        self.concentration = params_dict["concentration"]  # Uses property setter
        self.rate = params_dict["rate"]  # Uses property setter