Source code for spflow.modules.leaves.uniform

from __future__ import annotations

import torch
from einops import rearrange
from torch import Tensor

from spflow.exceptions import InvalidParameterCombinationError, InvalidParameterError
from spflow.meta.data import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.cache import Cache


[docs] class Uniform(LeafModule): """Uniform distribution leaf with fixed interval bounds. Note: Interval bounds are fixed buffers and cannot be learned. Attributes: start: Start of interval (fixed buffer). end: End of interval (fixed buffer). end_next: Next representable value after end. support_outside: Whether values outside [start, end] are supported. distribution: Underlying torch.distributions.Uniform. """
[docs] def __init__( self, scope: Scope, out_channels: int = 1, num_repetitions: int = 1, low: Tensor | None = None, high: Tensor | None = None, validate_args: bool | None = True, support_outside: bool = True, ): """Initialize Uniform distribution leaf. Args: scope: Variable scope. out_channels: Number of output channels (inferred from params if None). num_repetitions: Number of repetitions. low: Lower bound tensor (required). high: Upper bound tensor (required). validate_args: Whether to enable torch.distributions argument validation. support_outside: Whether values outside [start, end] are supported. """ if low is None or high is None: raise InvalidParameterCombinationError( "'low' and 'high' parameters are required for Uniform distribution" ) if not torch.isfinite(low).all() or not torch.isfinite(high).all(): raise InvalidParameterError("Parameter must be finite") super().__init__( scope=scope, out_channels=out_channels, # type: ignore num_repetitions=num_repetitions, params=[low, high], validate_args=validate_args, ) # Register interval bounds as torch buffers (should not be changed) self.register_buffer("low", torch.empty(size=[])) self.register_buffer("high", torch.empty(size=[])) self.register_buffer("end_next", torch.empty(size=[])) self.register_buffer("_support_outside", torch.empty(size=[])) self.low = low self.high = high self.end_next = torch.nextafter(high, high.new_tensor(float("inf"))) self._support_outside = high.new_tensor(support_outside, dtype=torch.bool)
@property def _supported_value(self): """Fallback value for unsupported data.""" return self.low @property def _torch_distribution_class(self) -> type[torch.distributions.Uniform]: return torch.distributions.Uniform @property def _torch_distribution_class_with_differentiable_sampling( self, ) -> type[torch.distributions.Uniform]: return torch.distributions.Uniform @property def mode(self) -> Tensor: """Returns the mode (midpoint) of the distribution.""" return (self.low + self.high) / 2
[docs] def params(self) -> dict[str, Tensor]: """Returns distribution parameters.""" return {"low": self.low, "high": self.high}
def _compute_parameter_estimates( self, data: Tensor, weights: Tensor, bias_correction: bool ) -> dict[str, Tensor]: """Compute raw MLE estimates for uniform distribution (without broadcasting). Note: For Uniform distribution, this is a no-op since parameters are fixed buffers. This method exists to maintain consistency with other leaf distributions. Args: data: Input data tensor. weights: Weight tensor for each data point. bias_correction: Whether to apply bias correction (unused for Uniform). Returns: Empty dictionary (no parameters to estimate). """ return {} def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None: """Set MLE-estimated parameters for Uniform distribution. Note: For Uniform distribution, this is a no-op since parameters are fixed buffers. The low and high bounds cannot be updated through MLE. This method exists to maintain consistency with other leaf distributions. Args: params_dict: Dictionary with parameter values (empty for Uniform). """ pass def _log_likelihood_interval( self, low: Tensor, high: Tensor, cache: Cache | None = None, ) -> Tensor: """Compute log P(low <= X <= high) for interval evidence. Args: low: Lower bounds of shape (batch, features). NaN = no lower bound. high: Upper bounds of shape (batch, features). NaN = no upper bound. cache: Optional cache dictionary. Returns: Log-likelihood tensor. """ # Get scope-filtered bounds low_scoped = low[:, self.scope.query] high_scoped = high[:, self.scope.query] # Expand to match (batch, features, channels, repetitions) low_expanded = rearrange(low_scoped, "b f -> b f 1 1") high_expanded = rearrange(high_scoped, "b f -> b f 1 1") # Distribution bounds a = self.low # (features, channels, reps) or scalar b = self.high # Handle NaN bounds (treat as -inf/+inf → clamp to distribution support) effective_low = torch.where(torch.isnan(low_expanded), a, torch.maximum(low_expanded, a)) effective_high = torch.where(torch.isnan(high_expanded), b, torch.minimum(high_expanded, b)) # Compute interval probability: (effective_high - effective_low) / (b - a) interval_length = torch.clamp(effective_high - effective_low, min=0.0) support_length = b - a prob = interval_length / support_length return torch.log(prob) def _expectation_maximization_step( self, data: torch.Tensor, bias_correction: bool = True, *, cache: Cache, ) -> None: del data, bias_correction, cache
class _Uniform(torch.distributions.Uniform): def mode(self): # We deviate from torch here, since torch returns NaN raise NotImplementedError("Mode is not defined (not unique) for Uniform distributions")