import torch
from einops import rearrange
from torch import Tensor, nn
from spflow.exceptions import InvalidParameterCombinationError
from spflow.meta.data import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.cache import Cache
from spflow.utils.leaves import init_parameter
from spflow.utils.projections import proj_convex_to_real, proj_real_to_convex
from spflow.utils.sampling_context import SIMPLE
class Categorical(LeafModule):
    """Categorical distribution leaf for discrete choice over K categories.

    The distribution is parameterized internally by unconstrained logits; the
    ``probs`` property exposes normalized probabilities via softmax.

    Attributes:
        p: Categorical probabilities (normalized, includes extra dimension for K).
        K: Number of categories.
        distribution: Underlying torch.distributions.Categorical.
    """

    def __init__(
        self,
        scope: Scope,
        out_channels: int = 1,
        num_repetitions: int = 1,
        K: int | None = None,
        probs: Tensor | None = None,
        logits: Tensor | None = None,
        parameter_fn: nn.Module | None = None,
        validate_args: bool | None = True,
    ):
        """Initialize Categorical distribution leaf module.

        Args:
            scope: The scope of the distribution.
            out_channels: Number of output channels.
            num_repetitions: Number of repetitions for the distribution.
            K: Number of categories. Optional if ``probs`` or ``logits`` is
                given, in which case K is inferred from the tensor's last
                dimension.
            probs: Probability tensor over categories (mutually exclusive with
                ``logits``).
            logits: Logits tensor over categories (mutually exclusive with
                ``probs``).
            parameter_fn: Optional neural network for parameter generation.
            validate_args: Whether to enable torch.distributions argument
                validation.

        Raises:
            InvalidParameterCombinationError: If neither ``K`` nor a parameter
                tensor is provided, or if both ``probs`` and ``logits`` are
                provided.
        """
        # K can be inferred from a provided parameter tensor if available.
        if K is None and probs is None and logits is None:
            raise InvalidParameterCombinationError(
                "Either 'K' or one of probs/logits must be provided for Categorical distribution"
            )
        if probs is not None and logits is not None:
            raise InvalidParameterCombinationError("Categorical accepts either probs or logits, not both.")
        param_source = logits if logits is not None else probs
        if K is None and param_source is not None:
            # Infer K from the last (category) dimension of the given tensor.
            K = int(param_source.shape[-1])
        # Unreachable given the checks above; narrows K to int for type checkers.
        assert K is not None, "K must be provided or inferred from params"
        super().__init__(
            scope=scope,
            out_channels=out_channels,  # type: ignore
            num_repetitions=num_repetitions,
            params=[param_source],
            parameter_fn=parameter_fn,  # type: ignore
            validate_args=validate_args,
        )
        self.K: int = K
        # The event shape gains a trailing category dimension of size K.
        param_shape = (*self._event_shape, K)
        init_value = init_parameter(
            param=param_source,
            event_shape=param_shape,
            # Random initialization produces a valid probability vector per event.
            init=lambda shape: torch.rand(shape).softmax(dim=-1),
        )
        # If logits were given, store them directly; otherwise init_value is a
        # probability tensor and must be projected to unconstrained (real) space.
        logits_tensor = init_value if logits is not None else proj_convex_to_real(init_value)
        self._logits = nn.Parameter(logits_tensor)

    @property
    def probs(self) -> Tensor:
        """Categorical probabilities in natural space (read via softmax of logits)."""
        return proj_real_to_convex(self._logits)

    @probs.setter
    def probs(self, value: Tensor) -> None:
        """Set categorical probabilities (stores as logits)."""
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = proj_convex_to_real(value_tensor)

    @property
    def logits(self) -> Tensor:
        """Logits directly parameterizing the categorical distribution."""
        return self._logits

    @logits.setter
    def logits(self, value: Tensor) -> None:
        # NOTE(review): no shape check here — assumes value matches the stored
        # parameter's shape; verify against callers.
        value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
        self._logits.data = value_tensor

    @property
    def _supported_value(self):
        """Fallback value for unsupported data."""
        return 1

    @property
    def _torch_distribution_class(self) -> type[torch.distributions.Categorical]:
        return torch.distributions.Categorical

    @property
    def _torch_distribution_class_with_differentiable_sampling(
        self,
    ) -> type[torch.distributions.Distribution]:
        return CategoricalWithDifferentiableSampling

    def params(self) -> dict[str, Tensor]:
        """Returns distribution parameters."""
        return {"logits": self.logits}

    def _compute_parameter_estimates(
        self, data: Tensor, weights: Tensor, bias_correction: bool
    ) -> dict[str, Tensor]:
        """Compute raw MLE estimates for categorical distribution (without broadcasting).

        Args:
            data: Input data tensor.
            weights: Weight tensor for each data point.
            bias_correction: Not used for Categorical (included for interface consistency).

        Returns:
            Dictionary with 'probs' estimates (shape: out_features x K).
        """
        # NOTE(review): if all weights are zero, n_total is 0 and the division
        # below yields NaN that the clamp does not repair — confirm callers
        # guarantee non-degenerate weights.
        n_total = weights.sum(dim=0)
        # self.K is always an int after __init__; the else branch is defensive
        # only, inferring the category count from the largest observed value.
        if self.K is not None:
            num_categories = self.K
        else:
            finite_values = data[~torch.isnan(data)]
            num_categories = int(finite_values.max().item()) + 1 if finite_values.numel() else 1
        p_est = torch.empty_like(self.probs)
        for cat in range(num_categories):
            # Weighted relative frequency of category `cat`; NaN entries in
            # `data` never compare equal and so contribute zero mass.
            cat_mask = (data == cat).float()
            p_est[..., cat] = torch.sum(weights * cat_mask, dim=0) / n_total
        # Handle edge cases (NaN or invalid probabilities) before broadcasting.
        # For categorical, we ensure probabilities sum to 1 and are non-negative.
        p_est = torch.clamp(p_est, min=1e-10)  # Avoid zero probabilities
        p_est = p_est / p_est.sum(dim=-1, keepdim=True)  # Renormalize across categories
        return {"probs": p_est}

    def _set_mle_parameters(self, params_dict: dict[str, Tensor]) -> None:
        """Set MLE-estimated parameters for Categorical distribution.

        Explicitly handles the parameter type:
        - probs: Property with setter, calls property setter which updates _logits

        Args:
            params_dict: Dictionary with 'probs' parameter values.
        """
        self.probs = params_dict["probs"]  # Uses property setter

    def _log_likelihood_interval(
        self,
        low: Tensor,
        high: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log P(low <= X <= high) for interval evidence.

        Sums probabilities for all categories k such that low <= k <= high.

        Args:
            low: Lower bounds of shape (batch, features).
            high: Upper bounds of shape (batch, features).
            cache: Optional cache dictionary (unused here).

        Returns:
            Log-likelihood tensor of shape (batch, features, channels, repetitions).
        """
        # probs has a trailing category dimension: (features, channels, repetitions, K).
        probs = self.probs
        # Restrict bounds to this leaf's scope.
        low_scoped = low[:, self.scope.query]
        high_scoped = high[:, self.scope.query]
        # Expand to match (batch, features, channels, repetitions).
        low_expanded = rearrange(low_scoped, "b f -> b f 1 1")
        high_expanded = rearrange(high_scoped, "b f -> b f 1 1")
        # NaN bounds mean "unbounded": lower falls back to 0, upper to K - 1.
        # Non-NaN bounds are rounded inward (ceil/floor) to integer categories.
        low_processed = torch.where(
            torch.isnan(low_expanded),
            torch.zeros_like(low_expanded),
            torch.ceil(low_expanded),
        )
        high_processed = torch.where(
            torch.isnan(high_expanded),
            torch.tensor(float(self.K - 1), device=high.device, dtype=high.dtype),
            torch.floor(high_expanded),
        )
        # Create category indices tensor [K].
        K = self.K
        categories = torch.arange(K, device=probs.device).float()
        # Reshape categories to [1, 1, 1, 1, K] for broadcasting against
        # [batch, features, channels, repetitions].
        categories = rearrange(categories, "k -> 1 1 1 1 k")
        # Reshape bounds to [batch, features, channels, repetitions, 1].
        low_b = rearrange(low_processed, "b f ci r -> b f ci r 1")
        high_b = rearrange(high_processed, "b f ci r -> b f ci r 1")
        # Mask of valid categories: low <= cat <= high.
        mask = (categories >= low_b) & (categories <= high_b)
        # Sum probabilities of valid categories.
        # probs shape: [features, channels, reps, K] -> unsqueeze a batch dim.
        probs_expanded = rearrange(probs, "f ci r k -> 1 f ci r k")
        # Zero out categories outside the interval: [batch, features, channels, reps, K].
        valid_probs = probs_expanded * mask
        # Sum over categories (last dim); clamp keeps log finite for empty intervals.
        total_prob = valid_probs.sum(dim=-1)
        return torch.log(torch.clamp(total_prob, min=1e-40))
class CategoricalWithDifferentiableSampling(torch.distributions.Categorical):
    """Categorical distribution with differentiable rsample via SIMPLE.

    Returns (straight-through) hard category indices as floating point values.
    """

    has_rsample = True

    def sample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        # Plain sampling is routed through the differentiable path as well.
        return self.rsample(sample_shape)

    def rsample(self, sample_shape: torch.Size = torch.Size()) -> Tensor:
        shape = torch.Size(sample_shape)
        logit_tensor = self.logits
        # Prepend the requested sample dimensions, if any, by broadcasting.
        if shape:
            logit_tensor = logit_tensor.expand(*shape, *logit_tensor.shape)
        # One-hot straight-through samples over the category dimension.
        one_hot = SIMPLE(logits=logit_tensor, dim=-1, is_mpe=False)
        # Collapse the one-hot encoding to floating-point category indices.
        index_values = torch.arange(
            logit_tensor.shape[-1], device=logit_tensor.device, dtype=logit_tensor.dtype
        )
        return (one_hot * index_values).sum(dim=-1)