from __future__ import annotations
import torch
from einops import rearrange
from torch import Tensor, nn
from spflow.exceptions import InvalidParameterCombinationError, InvalidParameterError
from spflow.meta.data import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.leaves import init_parameter
from spflow.utils.projections import proj_convex_to_real, proj_real_to_convex
from spflow.utils.sampling_context import SIMPLE
class HistogramDist:
"""Piecewise-constant histogram density with fixed bin edges.
This distribution models a continuous univariate random variable as a
piecewise-constant density over fixed bin edges. It mirrors the subset of
the ``torch.distributions`` interface required by :class:`~spflow.modules.leaves.leaf.LeafModule`:
``log_prob``, ``sample``, and ``mode``.
Shape conventions:
- Parameters have batch shape ``(F, C, R, B)`` where:
``F``=features (must be 1 for Histogram leaves), ``C``=channels,
``R``=repetitions, ``B``=number of bins.
- ``sample((N,))`` returns ``(N, F, C, R)``.
- ``log_prob(x)`` returns ``(N, F, C, R)``.
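    Example (a minimal sketch; one feature, one channel, one repetition):
        >>> edges = torch.tensor([0.0, 1.0, 2.0, 4.0])  # three bins of widths 1, 1, 2
        >>> dist = HistogramDist(bin_edges=edges, logits=torch.zeros(1, 1, 1, 3))
        >>> tuple(dist.sample((5,)).shape)  # samples follow (N, F, C, R)
        (5, 1, 1, 1)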
"""
def __init__(self, *, bin_edges: Tensor, logits: Tensor, min_prob: float = 1e-12) -> None:
if bin_edges.dim() != 1:
raise InvalidParameterError(f"bin_edges must be 1D, got shape {tuple(bin_edges.shape)}.")
if bin_edges.numel() < 2:
raise InvalidParameterError("bin_edges must contain at least two edges.")
if not torch.isfinite(bin_edges).all():
raise InvalidParameterError("bin_edges must be finite.")
if not torch.all(bin_edges[1:] > bin_edges[:-1]):
raise InvalidParameterError("bin_edges must be strictly increasing.")
if logits.dim() != 4:
raise InvalidParameterError(f"logits must be 4D (F,C,R,B), got shape {tuple(logits.shape)}.")
self._bin_edges = bin_edges
self._logits = logits
self._min_prob = float(min_prob)
@property
def bin_edges(self) -> Tensor:
return self._bin_edges
@property
def nbins(self) -> int:
return int(self._bin_edges.numel() - 1)
@property
def probs(self) -> Tensor:
return torch.softmax(self._logits, dim=-1)
@property
def _bin_widths(self) -> Tensor:
return self._bin_edges[1:] - self._bin_edges[:-1]
@property
def _bin_midpoints(self) -> Tensor:
return (self._bin_edges[:-1] + self._bin_edges[1:]) / 2.0
@property
def _bin_densities(self) -> Tensor:
widths = self._bin_widths
probs = self.probs
densities = probs / rearrange(widths, "b -> 1 1 1 b")
return densities
@property
def mode(self) -> Tensor:
"""Return mode as bin midpoint of the maximum-probability bin."""
max_idx = torch.argmax(self.probs, dim=-1) # (F, C, R)
mids = self._bin_midpoints.to(device=max_idx.device, dtype=self._logits.dtype) # (B,)
return mids[max_idx] # (F, C, R)
def _align_x(self, x: Tensor) -> Tensor:
if x.dim() == 2:
return rearrange(x, "n f -> n f 1 1")
if x.dim() == 3:
return rearrange(x, "n f c -> n f c 1")
if x.dim() == 4:
return x
raise InvalidParameterError(
f"Expected x to have shape (N,F), (N,F,C), or (N,F,C,R); got {tuple(x.shape)}."
)
def log_prob(self, x: Tensor) -> Tensor:
"""Compute log probability density.
Values outside the support ``[bin_edges[0], bin_edges[-1])`` receive ``-inf``.
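        Example (a minimal sketch with two unit-width bins):
            >>> d = HistogramDist(
            ...     bin_edges=torch.tensor([0.0, 1.0, 2.0]),
            ...     logits=torch.zeros(1, 1, 1, 2),
            ... )
            >>> lp = d.log_prob(torch.tensor([[0.5], [5.0]]))
            >>> bool(torch.isinf(lp[1]).all())  # the out-of-support value maps to -inf
            True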
"""
x = self._align_x(x)
_, num_features, _, _ = x.shape
if num_features != self._logits.shape[0]:
raise InvalidParameterError(
f"Feature mismatch: x has {num_features} features but logits have {self._logits.shape[0]}."
)
target_shape = (x.shape[0], *self._logits.shape[:-1])
edges = self._bin_edges.to(device=x.device, dtype=x.dtype)
bin_idx = torch.bucketize(x, edges, right=True) - 1
in_support = torch.isfinite(x) & (x >= edges[0]) & (x < edges[-1])
bin_idx_safe = bin_idx.clamp(0, self.nbins - 1).expand(target_shape)
in_support = in_support.expand(target_shape)
logits = self._logits.to(device=x.device, dtype=x.dtype)
log_widths = self._bin_widths.to(device=x.device, dtype=x.dtype).log()
log_densities = torch.log_softmax(logits, dim=-1) - rearrange(log_widths, "b -> 1 1 1 b")
gathered_log = (
log_densities.unsqueeze(0)
.expand(target_shape[0], -1, -1, -1, -1)
.gather(-1, bin_idx_safe.unsqueeze(-1))
.squeeze(-1)
)
max_width = self._bin_widths.max().to(device=gathered_log.device, dtype=gathered_log.dtype)
min_log_density = torch.log(gathered_log.new_tensor(self._min_prob)) - torch.log(max_width)
gathered_log = gathered_log.clamp_min(min_log_density)
return torch.where(in_support, gathered_log, gathered_log.new_full((), float("-inf")))
def sample(self, sample_shape: torch.Size | tuple[int, ...]) -> Tensor:
"""Sample values, uniformly within sampled bins."""
        n_samples = int(sample_shape[0]) if len(sample_shape) else 1
probs = self.probs
f, c, r, b = probs.shape
probs_flat = probs.reshape(-1, b) # (F*C*R, B)
cat = torch.distributions.Categorical(probs=probs_flat)
bin_idx = cat.sample((n_samples,)) # (N, F*C*R)
edges = self._bin_edges.to(device=bin_idx.device, dtype=self._logits.dtype)
left = edges[bin_idx]
right = edges[bin_idx + 1]
u = torch.rand_like(left)
x = left + u * (right - left)
return x.reshape(n_samples, f, c, r)
class HistogramDistWithDifferentiableSampling(HistogramDist):
"""Histogram distribution with differentiable sampling via SIMPLE over bins."""
has_rsample = True
def sample(self, sample_shape: torch.Size | tuple[int, ...]) -> Tensor:
return self.rsample(sample_shape)
def rsample(self, sample_shape: torch.Size | tuple[int, ...]) -> Tensor:
        n_samples = int(sample_shape[0]) if len(sample_shape) else 1
logits = self._logits.expand(n_samples, *self._logits.shape) # (N, F, C, R, B)
samples_oh = SIMPLE(logits=logits, dim=-1, is_mpe=False)
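        # samples_oh is a one-hot selection over the bin axis; SIMPLE keeps the
        # forward pass discrete while providing gradient estimates w.r.t. the logits.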
edges = self._bin_edges.to(device=logits.device, dtype=logits.dtype)
left_edges = edges[:-1]
right_edges = edges[1:]
left = (samples_oh * rearrange(left_edges, "b -> 1 1 1 1 b")).sum(dim=-1)
right = (samples_oh * rearrange(right_edges, "b -> 1 1 1 1 b")).sum(dim=-1)
u = torch.rand_like(left)
return left + u * (right - left)
class Histogram(LeafModule):
"""Histogram leaf distribution (continuous piecewise-constant density).
The histogram uses fixed bin edges and learnable bin probabilities (stored as logits).
Within each bin, the density is constant: ``p(bin) / width(bin)``.
Notes:
- This leaf is **univariate**: ``len(scope.query) == 1``.
- Values outside the support ``[bin_edges[0], bin_edges[-1])`` have log-likelihood ``-inf``.
- NaN values are marginalized out by :meth:`~spflow.modules.leaves.leaf.LeafModule.log_likelihood`
and contribute ``0`` to the log-likelihood.
- MPE returns the midpoint of the maximum-probability bin (per channel/repetition).
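    Example (a minimal sketch, assuming ``Scope([0])`` constructs a univariate scope):
        >>> leaf = Histogram(
        ...     scope=Scope([0]),
        ...     bin_edges=torch.tensor([0.0, 1.0, 2.0]),
        ...     out_channels=2,
        ... )
        >>> tuple(leaf.distribution().sample((4,)).shape)  # (N, F, C, R)
        (4, 1, 2, 1)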
"""
def __init__(
self,
scope: Scope,
*,
bin_edges: Tensor,
out_channels: int = 1,
num_repetitions: int = 1,
probs: Tensor | None = None,
logits: Tensor | None = None,
min_prob: float = 1e-12,
validate_args: bool | None = True,
) -> None:
if probs is not None and logits is not None:
raise InvalidParameterCombinationError("Histogram accepts either probs or logits, not both.")
if len(scope.query) != 1:
raise InvalidParameterError(
"Histogram leaf is univariate and requires scope with exactly one query RV."
)
bin_edges = torch.as_tensor(bin_edges, dtype=torch.float32)
if bin_edges.dim() != 1:
raise InvalidParameterError(f"bin_edges must be 1D, got shape {tuple(bin_edges.shape)}.")
if bin_edges.numel() < 2:
raise InvalidParameterError("bin_edges must contain at least two edges.")
if not torch.isfinite(bin_edges).all():
raise InvalidParameterError("bin_edges must be finite.")
if not torch.all(bin_edges[1:] > bin_edges[:-1]):
raise InvalidParameterError("bin_edges must be strictly increasing.")
self._min_prob = float(min_prob)
param_source = logits if logits is not None else probs
super().__init__(
scope=scope,
out_channels=out_channels, # type: ignore[arg-type]
num_repetitions=num_repetitions,
params=[param_source],
validate_args=validate_args,
)
nbins = int(bin_edges.numel() - 1)
if param_source is not None and int(param_source.shape[-1]) != nbins:
raise InvalidParameterError(
f"Last dim of probs/logits must match nbins={nbins}, got {int(param_source.shape[-1])}."
)
self.register_buffer("bin_edges", torch.empty(size=[]))
self.bin_edges = bin_edges
param_shape = (*self._event_shape, nbins)
init_value = init_parameter(
param=param_source,
event_shape=param_shape,
init=lambda shape: torch.rand(shape).softmax(dim=-1),
)
logits_tensor = init_value if logits is not None else proj_convex_to_real(init_value)
self._logits = nn.Parameter(logits_tensor)
@property
def logits(self) -> Tensor:
"""Unconstrained logits parameterizing bin probabilities."""
return self._logits
@logits.setter
def logits(self, value: Tensor) -> None:
value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
self._logits.data = value_tensor
@property
def probs(self) -> Tensor:
"""Bin probabilities in natural space (softmax of logits)."""
return proj_real_to_convex(self._logits)
@probs.setter
def probs(self, value: Tensor) -> None:
value_tensor = torch.as_tensor(value, dtype=self._logits.dtype, device=self._logits.device)
if not torch.isfinite(value_tensor).all():
raise InvalidParameterError("probs must be finite.")
if (value_tensor < 0).any():
raise InvalidParameterError("probs must be non-negative.")
value_tensor = value_tensor / value_tensor.sum(dim=-1, keepdim=True).clamp_min(self._min_prob)
self._logits.data = proj_convex_to_real(value_tensor.clamp_min(self._min_prob))
@property
def _torch_distribution_class(self):
"""Histogram uses a custom distribution, not a torch.distributions class."""
return None
def distribution(self, with_differentiable_sampling: bool = False) -> HistogramDist:
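        """Build the underlying :class:`HistogramDist`, optionally with SIMPLE-based differentiable sampling."""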
dist_cls = HistogramDistWithDifferentiableSampling if with_differentiable_sampling else HistogramDist
return dist_cls(
bin_edges=self.bin_edges.to(self._logits.device), logits=self._logits, min_prob=self._min_prob
)
@property
def _supported_value(self) -> Tensor:
"""Value in support used for NaN imputation prior to marginalization."""
return (self.bin_edges[0] + self.bin_edges[-1]) / 2.0
def params(self) -> dict[str, Tensor]:
return {"logits": self.logits}
def _compute_parameter_estimates(
self, data: Tensor, weights: Tensor, bias_correction: bool
) -> dict[str, Tensor]:
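        """Estimate bin probabilities as (weighted) relative bin counts of the data."""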
del bias_correction
        # data: (N, F=1, 1, 1); validate before rearranging so shape errors surface clearly.
        if data.dim() != 4 or data.shape[1] != 1:
            raise InvalidParameterError(
                f"Histogram expects univariate scoped data, got shape {tuple(data.shape)}."
            )
        x = rearrange(data, "n f 1 1 -> n f")
        w = weights  # (N, F, C, R)
edges = self.bin_edges.to(device=x.device, dtype=x.dtype)
nbins = int(edges.numel() - 1)
bin_idx = torch.bucketize(x, edges, right=True) - 1 # (N, 1)
in_support = torch.isfinite(x) & (x >= edges[0]) & (x < edges[-1])
if not in_support.all():
raise InvalidParameterError("MLE data contains values outside histogram support.")
bin_idx = rearrange(bin_idx, "n 1 -> n")
w_flat = rearrange(w[:, 0], "n c r -> n (c r)")
one_hot = torch.nn.functional.one_hot(bin_idx, nbins).to(dtype=w_flat.dtype) # (N, B)
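        # Weighted histogram counts: (C*R, N) @ (N, B) accumulates each sample's
        # weight into its bin, separately per channel/repetition.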
counts = w_flat.transpose(0, 1) @ one_hot # (C*R, B)
probs_est = counts / counts.sum(dim=-1, keepdim=True).clamp_min(self._min_prob)
probs_est = probs_est.clamp_min(self._min_prob)
probs_est = probs_est / probs_est.sum(dim=-1, keepdim=True)
probs_est = rearrange(
probs_est,
"(c r) b -> 1 c r b",
c=self.out_shape.channels,
r=self.out_shape.repetitions,
)
return {"probs": probs_est}
def marginalize(self, marg_rvs: list[int], prune: bool = True, cache=None):
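        """Return ``None`` if this leaf's scope is marginalized out, otherwise an unchanged copy."""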
del prune, cache
if self.is_conditional:
raise RuntimeError(
f"Marginalization not supported for conditional leaf {self.__class__.__name__}."
)
if any(rv in marg_rvs for rv in self.scope.query):
return None
        return Histogram(
            scope=self.scope.copy(),
            bin_edges=self.bin_edges.detach().clone(),
            logits=self.logits.detach(),
            out_channels=self.out_shape.channels,
            num_repetitions=self.out_shape.repetitions,
            min_prob=self._min_prob,
        )