# Source code for spflow.modules.leaves.categorical
import torch
from torch import Tensor, nn
from spflow.exceptions import InvalidParameterCombinationError
from spflow.meta.data import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import init_parameter
from spflow.utils.projections import proj_convex_to_real, proj_real_to_convex
class Categorical(LeafModule):
    """Categorical distribution leaf for discrete choice over K categories.

    Attributes:
        probs: Categorical probabilities (normalized; trailing dimension has size K).
        logits: Unnormalized log-probabilities parameterizing the distribution.
        K: Number of categories.
        distribution: Underlying torch.distributions.Categorical.
    """

    def __init__(
        self,
        scope: Scope,
        out_channels: int | None = None,
        num_repetitions: int = 1,
        K: int | None = None,
        probs: Tensor | None = None,
        logits: Tensor | None = None,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
    ):
        """Initialize Categorical distribution leaf module.

        Args:
            scope: The scope of the distribution.
            out_channels: Number of output channels (inferred from params if None).
            num_repetitions: Number of repetitions for the distribution.
            K: Number of categories (optional if a parameter tensor is provided).
            probs: Probability tensor over categories.
            logits: Logits tensor over categories.
            parameter_fn: Optional neural network for parameter generation.
            validate_args: Whether to enable torch.distributions argument validation.

        Raises:
            InvalidParameterCombinationError: If neither ``K`` nor one of
                ``probs``/``logits`` is given, or if both ``probs`` and
                ``logits`` are given.
        """
        # K can be inferred from a provided parameter tensor if available.
        if K is None and probs is None and logits is None:
            raise InvalidParameterCombinationError(
                "Either 'K' or one of probs/logits must be provided for Categorical distribution"
            )
        if probs is not None and logits is not None:
            raise InvalidParameterCombinationError("Categorical accepts either probs or logits, not both.")
        param_source = logits if logits is not None else probs
        if K is None and param_source is not None:
            # Infer K from the trailing (category) dimension of the given tensor.
            K = int(param_source.shape[-1])
        assert K is not None, "K must be provided or inferred from params"
        super().__init__(
            scope=scope,
            out_channels=out_channels,  # type: ignore
            num_repetitions=num_repetitions,
            params=[param_source],
            parameter_fn=parameter_fn,  # type: ignore
            validate_args=validate_args,
        )
        self.K: int = K
        # The event shape is extended by one trailing dimension of size K.
        param_shape = (*self._event_shape, K)
        init_value = init_parameter(
            param=param_source,
            event_shape=param_shape,
            # Default initialization: random normalized probabilities.
            init=lambda shape: torch.rand(shape).softmax(dim=-1),
        )
        # The parameter is always stored internally in logit (real) space; a
        # probability tensor is projected there first.
        logits_tensor = init_value if logits is not None else proj_convex_to_real(init_value)
        self._logits = nn.Parameter(logits_tensor)

    @property
    def probs(self) -> Tensor:
        """Categorical probabilities in natural space (read via softmax of logits)."""
        return proj_real_to_convex(self._logits)

    @probs.setter
    def probs(self, value: Tensor) -> None:
        """Set categorical probabilities (stores them as logits)."""
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = proj_convex_to_real(value_tensor)

    @property
    def logits(self) -> Tensor:
        """Logits directly parameterizing the categorical distribution."""
        return self._logits

    @logits.setter
    def logits(self, value: Tensor) -> None:
        """Set logits directly (no projection applied)."""
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = value_tensor

    @property
    def _supported_value(self) -> int:
        """Fallback value substituted for unsupported data."""
        return 1

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.Categorical]:
        """Torch distribution class backing this leaf."""
        return torch.distributions.Categorical

    def params(self) -> dict[str, Tensor]:
        """Returns distribution parameters (logits)."""
        return {"logits": self.logits}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw MLE estimates for categorical distribution (without broadcasting).

        Args:
            data: Input data tensor.
            weights: Weight tensor for each data point.
            bias_correction: Not used for Categorical (included for interface consistency).

        Returns:
            Dictionary with 'probs' estimates (shape: out_features x K).
        """
        n_total = weights.sum(dim=0)
        # self.K is always set in __init__; the data-driven branch is a defensive
        # fallback that infers the category count from the observed values.
        if self.K is not None:
            num_categories = self.K
        else:
            finite_values = data[~torch.isnan(data)]
            num_categories = int(finite_values.max().item()) + 1 if finite_values.numel() else 1
        # zeros_like (not empty_like): category slots the loop below never fills
        # (num_categories < K in the fallback branch) must hold 0, not
        # uninitialized memory.
        p_est = torch.zeros_like(self.probs)
        for cat in range(num_categories):
            cat_mask = (data == cat).float()
            p_est[..., cat] = torch.sum(weights * cat_mask, dim=0) / n_total
        # A feature with zero total weight produces 0/0 = NaN, which clamp would
        # NOT remove; map NaN to 1.0 so renormalization yields a uniform
        # distribution for such features.
        p_est = torch.nan_to_num(p_est, nan=1.0)
        p_est = torch.clamp(p_est, min=1e-10)  # Avoid exactly-zero probabilities
        p_est = p_est / p_est.sum(dim=-1, keepdim=True)  # Renormalize across categories
        return {"probs": p_est}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for Categorical distribution.

        Explicitly handles the parameter type:
        - probs: Property with setter, calls property setter which updates _logits

        Args:
            params_dict: Dictionary with 'probs' parameter values.
        """
        self.probs = params_dict["probs"]  # Uses property setter