import torch
from torch import Tensor, nn
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import validate_all_or_none, init_parameter, _handle_mle_edge_cases
class Gamma(LeafModule):
    """Gamma distribution leaf for modeling positive-valued continuous data.

    Parameterized by shape (concentration) alpha > 0 and rate beta > 0. Both are
    stored in log-space as ``nn.Parameter`` objects for numerical stability and
    are exposed in natural space through properties.

    Attributes:
        concentration: Shape parameter alpha (property; stored as ``log_concentration``).
        rate: Rate parameter beta (property; stored as ``log_rate``).
        log_concentration: ``nn.Parameter`` holding log(alpha).
        log_rate: ``nn.Parameter`` holding log(beta).

    The torch distribution class used by this leaf is ``torch.distributions.Gamma``
    (see ``_torch_distribution_class``).
    """

    def __init__(
        self,
        scope,
        out_channels: int | None = None,
        num_repetitions: int = 1,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
        concentration: Tensor | None = None,
        rate: Tensor | None = None,
    ):
        """Initialize Gamma distribution leaf.

        Args:
            scope: Variable scope (Scope, int, or list[int]).
            out_channels: Number of output channels (inferred from params if None).
            num_repetitions: Number of repetitions (for 3D event shapes).
            parameter_fn: Optional neural network for parameter generation; its
                outputs are interpreted in log-space by ``conditional_distribution``.
            validate_args: Whether to enable torch.distributions argument validation.
            concentration: Shape parameter alpha > 0. Randomly initialized if None.
            rate: Rate parameter beta > 0. Randomly initialized if None.
        """
        super().__init__(
            scope=scope,
            out_channels=out_channels,
            num_repetitions=num_repetitions,
            params=[concentration, rate],
            parameter_fn=parameter_fn,
            validate_args=validate_args,
        )
        # NOTE(review): presumably rejects passing only one of the two
        # parameters; exact semantics live in spflow.utils.leaves — confirm.
        validate_all_or_none(concentration=concentration, rate=rate)

        # Initialize parameters in a well-behaved range [0.5, 5.0) to avoid
        # extreme values that cause MLE instability.
        def init_gamma_param(shape):
            return torch.rand(shape) * 4.5 + 0.5

        # init_parameter is given the user value (possibly None) and the init
        # callable; presumably it only calls init_gamma_param when param is None.
        concentration = init_parameter(
            param=concentration, event_shape=self._event_shape, init=init_gamma_param
        )
        rate = init_parameter(param=rate, event_shape=self._event_shape, init=init_gamma_param)

        # Store log-values so unconstrained optimization keeps alpha, beta > 0.
        self.log_concentration = nn.Parameter(torch.log(concentration))
        self.log_rate = nn.Parameter(torch.log(rate))

    @property
    def concentration(self) -> Tensor:
        """Shape parameter alpha in natural space, computed as exp(log_concentration)."""
        return torch.exp(self.log_concentration)

    @concentration.setter
    def concentration(self, value: Tensor) -> None:
        """Set the shape parameter by overwriting ``log_concentration.data`` with log(value).

        The value is converted with ``torch.as_tensor`` to the parameter's dtype
        and device. No positivity check is performed: zero yields -inf and
        negative values yield NaN in log-space. Writing ``.data`` bypasses
        autograd tracking.
        """
        self.log_concentration.data = torch.log(
            torch.as_tensor(value, dtype=self.log_concentration.dtype, device=self.log_concentration.device)
        )

    @property
    def rate(self) -> Tensor:
        """Rate parameter beta in natural space, computed as exp(log_rate)."""
        return torch.exp(self.log_rate)

    @rate.setter
    def rate(self, value: Tensor) -> None:
        """Set the rate parameter by overwriting ``log_rate.data`` with log(value).

        The value is converted with ``torch.as_tensor`` to the parameter's dtype
        and device. No positivity check is performed: zero yields -inf and
        negative values yield NaN in log-space. Writing ``.data`` bypasses
        autograd tracking.
        """
        self.log_rate.data = torch.log(
            torch.as_tensor(value, dtype=self.log_rate.dtype, device=self.log_rate.device)
        )

    @property
    def _supported_value(self):
        """Fallback value for unsupported data (1.0 lies inside the Gamma support x > 0)."""
        return 1.0

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.Gamma]:
        """The torch distribution class backing this leaf."""
        return torch.distributions.Gamma

    def conditional_distribution(self, evidence: Tensor) -> torch.distributions.Gamma:
        """Build a Gamma distribution whose parameters are predicted from ``evidence``.

        ``self.parameter_fn`` must be set. It is called on ``evidence`` and must
        return a mapping with keys "concentration" and "rate" holding log-space
        values, which are exponentiated here to guarantee positivity.

        Args:
            evidence: Conditioning input forwarded to ``self.parameter_fn``.

        Returns:
            A ``torch.distributions.Gamma`` using this leaf's ``validate_args`` setting.
        """
        # Pass evidence to parameter network to get (log-space) parameters
        params = self.parameter_fn(evidence)
        # Apply exponential to ensure positive parameters and construct torch Gamma distribution
        return torch.distributions.Gamma(
            concentration=torch.exp(params["concentration"]),
            rate=torch.exp(params["rate"]),
            validate_args=self._validate_args,
        )

    def params(self) -> dict[str, Tensor]:
        """Return the distribution parameters in natural space (keys: concentration, rate)."""
        return {"concentration": self.concentration, "rate": self.rate}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw weighted parameter estimates for the Gamma distribution (no broadcasting).

        Uses the closed-form estimator of Ye & Chen (2017) rather than an
        iterative MLE solve. With E_w the weighted mean over dim 0 and
        n = weights.sum(dim=0):

            theta = E_w[x * ln x] - E_w[x] * E_w[ln x]   (scale estimate)
            alpha = E_w[x] / theta
            beta  = 1 / theta

        With ``bias_correction`` the shape receives the first-order correction
            alpha - (3*alpha - 2/3 * alpha/(1+alpha) - 4/5 * alpha/(1+alpha)**2) / n
        and the scale is inflated by n/(n-1), i.e. the rate is scaled by (n-1)/n.

        Non-positive entries in ``data`` make ``data.log()`` -inf/NaN; the
        resulting values are passed to ``_handle_mle_edge_cases`` (its exact
        handling is defined elsewhere — confirm).

        Args:
            data: Input data tensor (samples along dim 0).
            weights: Weight tensor for each data point, broadcastable with ``data``.
            bias_correction: Whether to apply bias correction to parameter estimates.

        Returns:
            Dictionary with 'concentration' and 'rate' estimates, reduced over dim 0.
        """
        n_total = weights.sum(dim=0)
        data_log = data.log()
        mean_xlnx = (weights * data_log * data).sum(dim=0) / n_total
        mean_x = (weights * data).sum(dim=0) / n_total
        mean_ln_x = (weights * data_log).sum(dim=0) / n_total

        # Weighted covariance of x and ln x; this is the closed-form scale estimate.
        theta_est = mean_xlnx - mean_x * mean_ln_x
        concentration_est = mean_x / theta_est
        rate_est = 1 / theta_est

        if bias_correction:
            concentration_est = concentration_est - 1 / n_total * (
                3 * concentration_est
                - 2 / 3 * (concentration_est / (1 + concentration_est))
                - 4 / 5 * (concentration_est / (1 + concentration_est) ** 2)
            )
            rate_est = rate_est * ((n_total - 1) / n_total)

        # Handle edge cases before broadcasting
        concentration_est = _handle_mle_edge_cases(concentration_est, lb=0.0)
        rate_est = _handle_mle_edge_cases(rate_est, lb=0.0)

        return {"concentration": concentration_est, "rate": rate_est}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for Gamma distribution.

        Both values go through their property setters, which write log(value)
        into ``log_concentration`` / ``log_rate``.

        Args:
            params_dict: Dictionary with 'concentration' and 'rate' parameter values.
        """
        self.concentration = params_dict["concentration"]  # Uses property setter
        self.rate = params_dict["rate"]  # Uses property setter