import torch
from einops import rearrange
from torch import Tensor, nn
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import init_parameter, _handle_mle_edge_cases
class Normal(LeafModule):
"""Normal (Gaussian) distribution leaf module.
Parameterized by mean μ and standard deviation σ (stored in log-space).
Attributes:
loc: Mean parameter.
std: Standard deviation (accessed via property, stored as log_std).
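
    Example (a minimal sketch; scope and shape semantics follow the
    constructor documentation below):

        >>> leaf = Normal(scope=[0, 1], out_channels=4)
        >>> leaf.scale.shape == leaf.loc.shape
        True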
"""
    def __init__(
        self,
        scope,
        out_channels: int = 1,
        num_repetitions: int = 1,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
        loc: Tensor | None = None,
        scale: Tensor | None = None,
    ):
"""Initialize Normal distribution.
Args:
scope: Variable scope. Can be a Scope object, a single integer,
or an iterable of integers (list, tuple, numpy array, torch tensor, etc.).
out_channels: Number of output channels (inferred from params if None).
num_repetitions: Number of repetitions (for 3D event shapes).
parameter_fn: Optional neural network for parameter generation.
loc: Mean tensor μ.
scale: Standard deviation tensor σ > 0.
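
        Example (a hedged sketch; assumes the event shape is
        ``(num_features, out_channels, num_repetitions)``, matching how the
        parameters are broadcast elsewhere in this module):

            >>> loc = torch.zeros(2, 1, 1)
            >>> scale = torch.ones(2, 1, 1)
            >>> leaf = Normal(scope=[0, 1], loc=loc, scale=scale)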
"""
super().__init__(
scope=scope,
out_channels=out_channels,
num_repetitions=num_repetitions,
params=[loc, scale],
parameter_fn=parameter_fn,
validate_args=validate_args,
)
        # Fill in missing parameters: random normal init for loc, uniform (0, 1) for scale.
        loc = init_parameter(param=loc, event_shape=self._event_shape, init=torch.randn)
        scale = init_parameter(param=scale, event_shape=self._event_shape, init=torch.rand)
        self.loc = nn.Parameter(loc)
        # Store the scale in log-space so it stays positive under unconstrained gradient updates.
        self.log_scale = nn.Parameter(torch.log(scale))

    @property
    def scale(self) -> Tensor:
        """Standard deviation in natural space (computed as ``exp(log_scale)``)."""
        return torch.exp(self.log_scale)

    @scale.setter
    def scale(self, value: Tensor) -> None:
        """Set the standard deviation (stored as ``log_scale``; not validated after init)."""
        self.log_scale.data = torch.log(
            torch.as_tensor(value, dtype=self.log_scale.dtype, device=self.log_scale.device)
        )
@property
def _supported_value(self):
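        # 0.0 lies in the Normal's support (all reals); presumably used as a safe fill value.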
return 0.0
@property
def _torch_distribution_class(self) -> type[torch.distributions.Normal]:
return torch.distributions.Normal
@property
def _torch_distribution_class_with_differentiable_sampling(
self,
) -> type[torch.distributions.Normal]:
return torch.distributions.Normal
    def params(self):
        """Return the distribution parameters as a dictionary."""
        return {"loc": self.loc, "scale": self.scale}
def _compute_parameter_estimates(
self, data: Tensor, weights: Tensor, bias_correction: bool
) -> dict[str, Tensor]:
"""Compute raw MLE estimates for normal distribution (without broadcasting).
Args:
data: Input data tensor.
weights: Weight tensor for each data point.
bias_correction: Whether to apply bias correction to variance estimate.
Returns:
Dictionary with 'loc' and 'scale' estimates (shape: out_features).
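
        Note:
            Uses the weighted estimates ``loc = Σᵢ wᵢ xᵢ / Σᵢ wᵢ`` and
            ``scale² = Σᵢ wᵢ (xᵢ − loc)² / (Σᵢ wᵢ − 1)``, where the ``− 1``
            is applied only when ``bias_correction`` is set.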
"""
        n_total = weights.sum(dim=0)
        # Weighted mean: sum_i w_i * x_i / sum_i w_i.
        loc_est = (weights * data).sum(0) / n_total
        centered = data - loc_est
        # Weighted variance, with Bessel's correction (denominator n - 1) if requested.
        var_numerator = (weights * centered.pow(2)).sum(0)
        var_denom = n_total - 1 if bias_correction else n_total
        scale_est = torch.sqrt(var_numerator / var_denom)
# Handle edge cases (NaN, zero, or near-zero std)
scale_est = _handle_mle_edge_cases(scale_est, lb=0.0)
return {"loc": loc_est, "scale": scale_est}
def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
"""Set MLE-estimated parameters for Normal distribution.
Explicitly handles the two parameter types:
- loc: Direct nn.Parameter, update .data attribute
- scale: Property with setter, calls property setter which updates log_scale
Args:
params_dict: Dictionary with 'loc' and 'scale' parameter values.
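
        Example (illustrative; resets the leaf to a standard normal):

            >>> leaf._set_mle_parameters({"loc": torch.zeros_like(leaf.loc),
            ...                           "scale": torch.ones_like(leaf.loc)})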
"""
self.loc.data = params_dict["loc"]
self.scale = params_dict["scale"] # Uses property setter
def _log_likelihood_interval(
self,
low: Tensor,
high: Tensor,
cache=None,
) -> Tensor:
"""Compute log P(low <= X <= high) for interval evidence.
Args:
low: Lower bounds of shape (batch, features). NaN = no lower bound.
high: Upper bounds of shape (batch, features). NaN = no upper bound.
cache: Optional cache dictionary.
Returns:
Log-likelihood tensor.
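
        Example (illustrative; assumes ``leaf`` is a ``Normal`` over two features):

            >>> low = torch.tensor([[0.0, float("nan")]])  # NaN: no lower bound for feature 1
            >>> high = torch.tensor([[1.0, 2.5]])
            >>> ll = leaf._log_likelihood_interval(low, high)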
"""
# Get scope-filtered bounds
low_scoped = low[:, self.scope.query]
high_scoped = high[:, self.scope.query]
# Expand to match (batch, features, channels, repetitions)
low_expanded = rearrange(low_scoped, "b f -> b f 1 1")
high_expanded = rearrange(high_scoped, "b f -> b f 1 1")
loc = self.loc
scale = self.scale
# Handle NaN bounds as -inf/+inf
low_for_cdf = torch.where(
torch.isnan(low_expanded),
torch.full_like(low_expanded, float("-inf")),
low_expanded,
)
high_for_cdf = torch.where(
torch.isnan(high_expanded),
torch.full_like(high_expanded, float("inf")),
high_expanded,
)
        # Gaussian CDF via the error function: Φ(x) = 0.5 * (1 + erf((x − μ) / (σ · √2))).
        # torch.erf maps −inf → −1 and +inf → 1, so unbounded sides yield CDF 0 or 1.
        sqrt2 = 1.4142135623730951  # √2
        cdf_high = 0.5 * (1 + torch.erf((high_for_cdf - loc) / (scale * sqrt2)))
        cdf_low = 0.5 * (1 + torch.erf((low_for_cdf - loc) / (scale * sqrt2)))
# P(low <= X <= high) = CDF(high) - CDF(low)
prob = torch.clamp(cdf_high - cdf_low, min=1e-40) # Numerical stability
return torch.log(prob)