Source code for spflow.modules.leaves.uniform

from __future__ import annotations

import torch
from torch import Tensor

from spflow.exceptions import InvalidParameterCombinationError, InvalidParameterError
from spflow.meta.data import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.cache import Cache


class Uniform(LeafModule):
    """Uniform distribution leaf with fixed interval bounds.

    Note:
        The interval bounds are registered as buffers, not parameters, and
        every estimation hook on this class is a no-op, so the bounds are
        never learned.

    Attributes:
        low: Lower interval bound (buffer).
        high: Upper interval bound (buffer).
        end_next: Next representable floating-point value after ``high``
            in the direction of +inf (buffer).
        _support_outside: Boolean scalar tensor holding the
            ``support_outside`` flag passed to the constructor (buffer).
    """

    # Annotation-only declarations: they create no class attributes (so
    # ``register_buffer`` below is not blocked) but give static type checkers
    # a precise type for the buffers.
    low: Tensor
    high: Tensor
    end_next: Tensor
    _support_outside: Tensor

    def __init__(
        self,
        scope: Scope,
        out_channels: int | None = None,
        num_repetitions: int = 1,
        low: Tensor | None = None,
        high: Tensor | None = None,
        validate_args: bool | None = True,
        support_outside: bool = True,
    ) -> None:
        """Initialize Uniform distribution leaf.

        Args:
            scope: Variable scope.
            out_channels: Number of output channels (inferred from params if None).
            num_repetitions: Number of repetitions.
            low: Lower bound tensor (required).
            high: Upper bound tensor (required).
            validate_args: Whether to enable torch.distributions argument validation.
            support_outside: Whether values outside [low, high] are supported.

        Raises:
            InvalidParameterCombinationError: If ``low`` or ``high`` is None.
            InvalidParameterError: If ``low`` or ``high`` contains a
                non-finite value.
        """
        if low is None or high is None:
            raise InvalidParameterCombinationError(
                "'low' and 'high' parameters are required for Uniform distribution"
            )
        if not torch.isfinite(low).all() or not torch.isfinite(high).all():
            raise InvalidParameterError("Parameter must be finite")
        # NOTE(review): the ordering low < high is not checked here; presumably
        # torch.distributions.Uniform enforces it when validate_args is enabled
        # -- confirm before relying on it.

        super().__init__(
            scope=scope,
            out_channels=out_channels,  # type: ignore
            num_repetitions=num_repetitions,
            params=[low, high],
            validate_args=validate_args,
        )

        # Register the final tensors directly as buffers. Going through plain
        # attribute assignment instead would turn an nn.Parameter argument
        # into a learnable parameter, which contradicts the fixed-bounds
        # contract of this leaf.
        self.register_buffer("low", low)
        self.register_buffer("high", high)
        self.register_buffer(
            "end_next",
            torch.nextafter(high, high.new_tensor(float("inf"))),
        )
        self.register_buffer(
            "_support_outside",
            high.new_tensor(support_outside, dtype=torch.bool),
        )

    @property
    def _supported_value(self) -> Tensor:
        """Fallback value for unsupported data (the lower bound)."""
        return self.low

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.Uniform]:
        """Return the torch distribution class backing this leaf."""
        return torch.distributions.Uniform

    @property
    def mode(self) -> Tensor:
        """Returns the mode (midpoint) of the distribution."""
        return (self.low + self.high) / 2

    def params(self) -> dict[str, Tensor]:
        """Returns distribution parameters keyed by ``"low"`` and ``"high"``."""
        return {"low": self.low, "high": self.high}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw MLE estimates for uniform distribution (without broadcasting).

        Note:
            For Uniform distribution, this is a no-op since parameters are
            fixed buffers. This method exists to maintain consistency with
            other leaf distributions.

        Args:
            data: Input data tensor (unused).
            weights: Weight tensor for each data point (unused).
            bias_correction: Whether to apply bias correction (unused for Uniform).

        Returns:
            Empty dictionary (no parameters to estimate).
        """
        return {}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for Uniform distribution.

        Note:
            For Uniform distribution, this is a no-op since parameters are
            fixed buffers. The low and high bounds cannot be updated through
            MLE. This method exists to maintain consistency with other leaf
            distributions.

        Args:
            params_dict: Dictionary with parameter values (empty for Uniform).
        """
        pass

    def expectation_maximization(
        self,
        data: Tensor,
        bias_correction: bool = False,
        cache: Cache | None = None,
    ) -> None:
        """Run an EM step; a no-op for Uniform.

        The bounds are fixed buffers, so nothing is updated and all arguments
        are ignored.

        Args:
            data: Input data tensor (unused).
            bias_correction: Whether to apply bias correction (unused).
            cache: Optional cache object (unused).
        """
        pass