"""Categorical distribution leaf (spflow.modules.leaves.categorical)."""

import torch
from einops import rearrange
from torch import Tensor, nn

from spflow.exceptions import InvalidParameterCombinationError
from spflow.meta.data import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.cache import Cache
from spflow.utils.leaves import init_parameter
from spflow.utils.projections import proj_convex_to_real, proj_real_to_convex
from spflow.utils.sampling_context import SIMPLE


class Categorical(LeafModule):
    """Categorical distribution leaf for discrete choice over K categories.

    Attributes:
        p: Categorical probabilities (normalized, includes extra dimension for K).
        K: Number of categories.
        distribution: Underlying torch.distributions.Categorical.
    """

    def __init__(
        self,
        scope: Scope,
        out_channels: int = 1,
        num_repetitions: int = 1,
        K: int | None = None,
        probs: Tensor | None = None,
        logits: Tensor | None = None,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
    ):
        """Initialize Categorical distribution leaf module.

        Args:
            scope: The scope of the distribution.
            out_channels: Number of output channels.
            num_repetitions: Number of repetitions for the distribution.
            K: Number of categories (optional if a parameter tensor is provided).
            probs: Probability tensor over categories.
            logits: Logits tensor over categories.
            parameter_fn: Optional neural network for parameter generation.
            validate_args: Whether to enable torch.distributions argument validation.

        Raises:
            InvalidParameterCombinationError: If neither ``K`` nor one of
                ``probs``/``logits`` is given, or if both ``probs`` and
                ``logits`` are given.
        """
        # K can be inferred from a provided parameter tensor if available.
        if K is None and probs is None and logits is None:
            raise InvalidParameterCombinationError(
                "Either 'K' or one of probs/logits must be provided for Categorical distribution"
            )
        if probs is not None and logits is not None:
            raise InvalidParameterCombinationError("Categorical accepts either probs or logits, not both.")

        param_source = logits if logits is not None else probs
        if K is None and param_source is not None:
            # Infer K from the trailing (category) dimension of the tensor.
            K = int(param_source.shape[-1])
        assert K is not None, "K must be provided or inferred from params"

        super().__init__(
            scope=scope,
            out_channels=out_channels,  # type: ignore
            num_repetitions=num_repetitions,
            params=[param_source],
            parameter_fn=parameter_fn,  # type: ignore
            validate_args=validate_args,
        )
        self.K: int = K

        # Initialize parameter with K categories; the random init is a valid
        # probability vector by construction (softmax over the last dim).
        param_shape = (*self._event_shape, K)
        init_value = init_parameter(
            param=param_source,
            event_shape=param_shape,
            init=lambda shape: torch.rand(shape).softmax(dim=-1),
        )
        # Store logits internally: use them directly when given as logits,
        # otherwise project the probabilities into unconstrained real space.
        logits_tensor = init_value if logits is not None else proj_convex_to_real(init_value)
        self._logits = nn.Parameter(logits_tensor)

    @property
    def probs(self) -> Tensor:
        """Categorical probabilities in natural space (read via softmax of logits)."""
        return proj_real_to_convex(self._logits)

    @probs.setter
    def probs(self, value: Tensor) -> None:
        """Set categorical probabilities (stores as logits)."""
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = proj_convex_to_real(value_tensor)

    @property
    def logits(self) -> Tensor:
        """Logits directly parameterizing the categorical distribution."""
        return self._logits

    @logits.setter
    def logits(self, value: Tensor) -> None:
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = value_tensor

    @property
    def _supported_value(self):
        """Fallback value for unsupported data."""
        return 1

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.Categorical]:
        return torch.distributions.Categorical

    @property
    def _torch_distribution_class_with_differentiable_sampling(
        self,
    ) -> type[torch.distributions.Distribution]:
        return CategoricalWithDifferentiableSampling

    def params(self) -> dict[str, Tensor]:
        """Returns distribution parameters."""
        return {"logits": self.logits}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw MLE estimates for categorical distribution (without broadcasting).

        Args:
            data: Input data tensor.
            weights: Weight tensor for each data point.
            bias_correction: Not used for Categorical (included for interface consistency).

        Returns:
            Dictionary with 'probs' estimates (shape: out_features x K).
        """
        n_total = weights.sum(dim=0)

        if self.K is not None:
            num_categories = self.K
        else:
            # Fallback: infer the category count from the largest observed
            # value, ignoring NaN (missing) entries.
            finite_values = data[~torch.isnan(data)]
            num_categories = int(finite_values.max().item()) + 1 if finite_values.numel() > 0 else 1

        # zeros_like (NOT empty_like): category slots beyond num_categories
        # must be exactly 0, not uninitialized memory, so that the clamp and
        # renormalization below are well-defined.
        p_est = torch.zeros_like(self.probs)

        # Guard against an all-zero weight column: 0/0 would produce NaN,
        # which clamp() would NOT remove. With a tiny positive denominator the
        # counts become 0 and renormalize to a uniform distribution instead.
        safe_total = n_total.to(p_est.dtype).clamp_min(torch.finfo(p_est.dtype).tiny)

        for cat in range(num_categories):
            cat_mask = (data == cat).float()
            p_est[..., cat] = torch.sum(weights * cat_mask, dim=0) / safe_total

        # Ensure probabilities are strictly positive and sum to 1.
        p_est = torch.clamp(p_est, min=1e-10)  # Avoid zero probabilities
        p_est = p_est / p_est.sum(dim=-1, keepdim=True)  # Renormalize across categories

        return {"probs": p_est}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for Categorical distribution.

        Explicitly handles the parameter type:
        - probs: Property with setter, calls property setter which updates _logits

        Args:
            params_dict: Dictionary with 'probs' parameter values.
        """
        self.probs = params_dict["probs"]  # Uses property setter

    def _log_likelihood_interval(
        self,
        low: Tensor,
        high: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log P(low <= X <= high) for interval evidence.

        Sums probabilities for all categories k such that low <= k <= high.

        Args:
            low: Lower bounds of shape (batch, features).
            high: Upper bounds of shape (batch, features).
            cache: Optional cache dictionary.

        Returns:
            Log-likelihood tensor.
        """
        # probs shape: (features, channels, repetitions, K)
        probs = self.probs

        # Restrict bounds to the variables in this leaf's scope.
        low_scoped = low[:, self.scope.query]
        high_scoped = high[:, self.scope.query]

        # Expand to match (batch, features, channels, repetitions).
        low_expanded = rearrange(low_scoped, "b f -> b f 1 1")
        high_expanded = rearrange(high_scoped, "b f -> b f 1 1")

        # NaN bounds mean "unbounded": lower -> 0, upper -> K-1. Finite bounds
        # are rounded inward (ceil/floor) to integer category indices.
        low_processed = torch.where(
            torch.isnan(low_expanded),
            torch.zeros_like(low_expanded),
            torch.ceil(low_expanded),
        )
        high_processed = torch.where(
            torch.isnan(high_expanded),
            torch.tensor(float(self.K - 1), device=high.device, dtype=high.dtype),
            torch.floor(high_expanded),
        )

        # Category indices [K], reshaped to [1, 1, 1, 1, K] for broadcasting
        # against [batch, features, channels, repetitions].
        K = self.K
        categories = torch.arange(K, device=probs.device).float()
        categories = rearrange(categories, "k -> 1 1 1 1 k")

        # Bounds reshaped to [batch, features, channels, repetitions, 1].
        low_b = rearrange(low_processed, "b f ci r -> b f ci r 1")
        high_b = rearrange(high_processed, "b f ci r -> b f ci r 1")

        # Mask of valid categories: low <= cat <= high.
        mask = (categories >= low_b) & (categories <= high_b)

        # Sum probabilities of the selected categories; unsqueeze a batch dim
        # on probs first so broadcasting lines up with the mask.
        probs_expanded = rearrange(probs, "f ci r k -> 1 f ci r k")
        valid_probs = probs_expanded * mask
        total_prob = valid_probs.sum(dim=-1)

        # Clamp to avoid log(0) for empty intervals.
        return torch.log(torch.clamp(total_prob, min=1e-40))
class CategoricalWithDifferentiableSampling(torch.distributions.Categorical):
    """Categorical distribution with differentiable rsample via SIMPLE.

    Returns (straight-through) hard category indices as floating point values.
    """

    has_rsample = True

    def sample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        # Delegate to rsample so every sampling path stays differentiable.
        return self.rsample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        shape = torch.Size(sample_shape)
        raw_logits = self.logits
        if shape:
            # Replicate the logits across the requested sample dimensions.
            raw_logits = raw_logits.expand(*shape, *raw_logits.shape)
        # SIMPLE yields a straight-through one-hot sample over the last dim.
        one_hot = SIMPLE(logits=raw_logits, dim=-1, is_mpe=False)
        index_values = torch.arange(
            raw_logits.shape[-1], device=raw_logits.device, dtype=raw_logits.dtype
        )
        # Collapse the one-hot encoding into (float) category indices.
        return torch.sum(one_hot * index_values, dim=-1)