from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Optional, Dict, Callable, Iterable, Union
import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor
from spflow.exceptions import MissingCacheError, ShapeError, UnsupportedOperationError
from spflow.meta.data.interval_evidence import IntervalEvidence
from spflow.meta.data.scope import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.utils.cache import Cache, cached
from spflow.utils.leaves import apply_nan_strategy, parse_leaf_args
from spflow.utils.sampling_context import LeafParamRecord, SamplingContext
from spflow.utils.sampling_context import index_one_hot, index_tensor, repeat_repetition_index
[docs]
class LeafModule(Module, ABC):
[docs]
def __init__(
    self,
    scope: Scope | int | Iterable[int],
    out_channels: int = 1,
    num_repetitions: int = 1,
    params: list[Tensor | None] | None = None,
    parameter_fn: Callable[[Tensor], dict[str, Tensor]] | None = None,
    validate_args: bool | None = True,
):
    """Base class for leaf distribution modules.

    Args:
        scope: Variable scope. Can be a Scope object, a single integer,
            or an iterable of integers (list, tuple, numpy array, torch tensor, etc.).
        out_channels: Number of output channels. Forwarded, together with
            ``params``, to ``parse_leaf_args`` which derives the event shape.
        num_repetitions: Number of repetitions (for 3D event shapes).
        params: List of parameter tensors (can include None to trigger random init).
        parameter_fn: Optional function that takes evidence and returns distribution
            parameters as dictionary. When given, the leaf reports
            ``is_conditional == True``.
        validate_args: Whether to enable torch.distributions argument validation.
    """
    super().__init__()

    event_shape = parse_leaf_args(
        scope=scope,
        out_channels=out_channels,
        num_repetitions=num_repetitions,
        params=params,
    )

    # If not already a Scope, convert int or list[int] to Scope
    if not isinstance(scope, Scope):
        scope = Scope(scope)

    # Store a copy so the leaf does not share the caller's Scope object.
    self.scope = scope.copy()
    self._event_shape = event_shape
    self.parameter_fn = parameter_fn
    self._validate_args = validate_args

    # Shape computation from _event_shape; missing trailing axes count as size 1.
    features = self._event_shape[0]
    channels = self._event_shape[1] if len(self._event_shape) > 1 else 1
    repetitions = self._event_shape[2] if len(self._event_shape) > 2 else 1

    self.in_shape = ModuleShape(features, 1, 1)
    self.out_shape = ModuleShape(features, channels, repetitions)
@property
def inputs(self) -> Module | Iterable[Module]:
    """Always raise: a leaf is a terminal node and has no input modules."""
    message = "LeafModule does not have 'input' attribute -- this should not have been called."
    raise AttributeError(message)
@inputs.setter
def inputs(self, value):
    """Always raise: input modules cannot be assigned to a leaf."""
    message = "LeafModule does not have 'input' attribute -- this should not have been called."
    raise AttributeError(message)
@property
def is_conditional(self):
    """Report whether a ``parameter_fn`` was supplied, i.e. parameters come from evidence."""
    has_parameter_fn = self.parameter_fn is not None
    return has_parameter_fn
[docs]
def distribution(self, with_differentiable_sampling: bool = False) -> torch.distributions.Distribution:
    """Build the torch distribution for this leaf from its current ``params()``.

    Args:
        with_differentiable_sampling: Forwarded to the internal distribution factory,
            which then instantiates ``_torch_distribution_class_with_differentiable_sampling``
            instead of ``_torch_distribution_class``.

    Returns:
        The distribution constructed from the leaf's own parameters.
    """
    current_params = self.params()
    return self.__make_distribution(
        current_params,
        with_differentiable_sampling=with_differentiable_sampling,
    )
@property
@abstractmethod
def _torch_distribution_class(self) -> type[torch.distributions.Distribution]:
    """Distribution class that subclasses must provide.

    It is instantiated with ``validate_args`` plus the leaf's parameter dictionary.
    """
@property
def _torch_distribution_class_with_differentiable_sampling(
    self,
) -> type[torch.distributions.Distribution]:
    """Distribution class used when differentiable sampling is requested.

    The base implementation has no such alternative and always raises;
    subclasses override this property to opt in.

    Raises:
        NotImplementedError: Always, unless overridden.
    """
    message = (
        f"{self.__class__.__name__} does not implement an alternative distribution with differentiable sampling. "
        f"Override _torch_distribution_class_with_differentiable_sampling or set with_differentiable_sampling=False when calling distribution()."
    )
    raise NotImplementedError(message)
def __make_distribution(
    self, params: Dict[str, Tensor], with_differentiable_sampling: bool = False
) -> torch.distributions.Distribution:
    """Instantiate the leaf's distribution class from a parameter dictionary.

    Args:
        params: Keyword arguments passed to the distribution constructor.
        with_differentiable_sampling: Select the differentiable-sampling class
            instead of the regular one.

    Returns:
        The constructed distribution; ``validate_args`` is taken from the leaf.
    """
    # Only the selected property is evaluated; the differentiable one may raise
    # NotImplementedError on leaves that do not provide it.
    distribution_cls = (
        self._torch_distribution_class_with_differentiable_sampling
        if with_differentiable_sampling
        else self._torch_distribution_class
    )
    return distribution_cls(validate_args=self._validate_args, **params)  # type: ignore[call-arg]
[docs]
def conditional_distribution(
    self, evidence: Tensor, with_differentiable_sampling: bool = False
) -> torch.distributions.Distribution:
    """Build the distribution whose parameters are produced by ``parameter_fn(evidence)``.

    Args:
        evidence: Evidence tensor handed to ``parameter_fn``.
        with_differentiable_sampling: Forwarded to the internal distribution factory
            to select the differentiable-sampling distribution class.

    Returns:
        The distribution constructed from the conditional parameters.

    Raises:
        ValueError: If ``evidence`` is None.
    """
    if evidence is not None:
        conditional_params = self.parameter_fn(evidence)
        return self.__make_distribution(
            conditional_params,
            with_differentiable_sampling=with_differentiable_sampling,
        )
    raise ValueError("Evidence tensor must be provided for conditional distribution.")
@property
@abstractmethod
def _supported_value(self) -> Union[float, Tensor]:
    """Value inside the distribution's support, used to fill NaN entries.

    Subclasses return either a plain float or a tensor. A tensor must be
    a single-element tensor or carry a leading feature dimension equal to
    ``len(self.scope.query)``.
    """
def _supported_value_for_imputation(self, scoped_data: Tensor) -> Tensor:
    """Expand ``_supported_value`` into a (batch, features) tensor of fill values.

    Leaves may expose their supported value as a number, a single-element tensor,
    a per-feature vector, or a tensor with extra trailing axes (e.g. channels and
    repetitions). All of these are normalised to the shape of ``scoped_data``.

    Args:
        scoped_data: 2D tensor of shape (batch, features) for the leaf scope.

    Returns:
        Tensor with the same shape as ``scoped_data``.

    Raises:
        ShapeError: If ``scoped_data`` is not 2D or the supported value cannot be
            aligned with the feature axis.
        TypeError: If ``_supported_value`` is neither a number nor a Tensor.
    """
    if scoped_data.dim() != 2:
        raise ShapeError(
            f"Expected scoped_data to be 2D (batch, features), got shape {tuple(scoped_data.shape)}."
        )

    supported_value = self._supported_value

    # Plain Python numbers: one constant for every entry.
    if isinstance(supported_value, (float, int)):
        return torch.full_like(scoped_data, float(supported_value))

    if not isinstance(supported_value, Tensor):
        raise TypeError(
            f"{self.__class__.__name__}._supported_value must be a float or Tensor, got {type(supported_value)}."
        )

    supported_tensor = supported_value.to(device=scoped_data.device, dtype=scoped_data.dtype)
    batch_size, num_features = scoped_data.shape[0], scoped_data.shape[1]

    # Single-element tensors behave like a scalar regardless of their rank.
    if supported_tensor.numel() == 1:
        scalar = supported_tensor.reshape(())
        return repeat(scalar, "-> b f", b=batch_size, f=num_features)

    leading_matches = supported_tensor.shape[0] == num_features

    if supported_tensor.dim() == 1:
        if not leading_matches:
            raise ShapeError(
                f"{self.__class__.__name__}._supported_value has shape {tuple(supported_tensor.shape)}; "
                f"expected ({num_features},)."
            )
        return repeat(supported_tensor, "f -> b f", b=batch_size)

    if not leading_matches:
        raise ShapeError(
            f"{self.__class__.__name__}._supported_value has shape {tuple(supported_tensor.shape)}; "
            f"expected first dimension to match number of features ({num_features})."
        )

    # Drop every trailing axis by taking its first entry until only features remain.
    feature_values = supported_tensor
    while feature_values.dim() > 1:
        feature_values = feature_values.select(dim=1, index=0)

    if feature_values.shape != (num_features,):
        raise ShapeError(
            f"{self.__class__.__name__}._supported_value could not be reduced to per-feature values; "
            f"got shape {tuple(feature_values.shape)}."
        )

    return repeat(feature_values, "f -> b f", b=batch_size)
[docs]
@abstractmethod
def params(self) -> Dict[str, Tensor]:
    """Return the distribution parameters as a name-to-tensor dictionary.

    The dictionary is unpacked as keyword arguments into the distribution class.
    """
[docs]
def mode(self, is_differentiable: bool = False) -> Tensor:
    """Return distribution mode.

    Args:
        is_differentiable: Whether to return the mode from the differentiable
            distribution. Forwarded to ``distribution()`` as
            ``with_differentiable_sampling``; leaves without a differentiable
            distribution class raise ``NotImplementedError`` when this is True.

    Returns:
        Mode of the distribution.
    """
    # Previously this flag was accepted but never used, so callers asking for the
    # differentiable variant silently received the regular distribution's mode.
    return self.distribution(with_differentiable_sampling=is_differentiable).mode
[docs]
def marginalized_params(self, indices: list[int]) -> dict[str, Tensor]:
    """Index every parameter tensor along its first axis with ``indices``.

    Args:
        indices: Positions (along the leading axis) to keep.

    Returns:
        New dictionary with the same keys as ``params()`` and indexed values.
    """
    selected: dict[str, Tensor] = {}
    for name, tensor in self.params().items():
        selected[name] = tensor[indices]
    return selected
def _mle_update_statistics(self, data: Tensor, weights: Tensor, bias_correction: bool) -> None:
    """Estimate parameters from weighted data and store them on the leaf.

    Args:
        data: Input data tensor of shape (batch, features).
        weights: Weight tensor for each data point.
        bias_correction: Whether to apply bias correction to variance estimate.
    """
    # Append two singleton axes so data lines up with 4D weights.
    expanded = rearrange(data, "b f -> b f 1 1")
    self._set_mle_parameters(
        self._compute_parameter_estimates(expanded, weights, bias_correction)
    )
@abstractmethod
def _compute_parameter_estimates(
    self, data: Tensor, weights: Tensor, bias_correction: bool
) -> Dict[str, Tensor]:
    """Distribution-specific hook that turns weighted data into parameter estimates.

    Called by ``_mle_update_statistics``; the returned dictionary is passed on to
    ``_set_mle_parameters``.

    Args:
        data: Scope-filtered data.
        weights: Normalized weights.
        bias_correction: Apply bias correction.

    Returns:
        Dictionary mapping parameter names to their estimates.
    """
def _set_mle_parameters(self, params_dict: Dict[str, Tensor]) -> None:
    """Write estimated parameter values back onto the leaf.

    Each entry is first assigned with ``setattr`` (which goes through a property
    setter when one exists). If that assignment raises ``TypeError``, the existing
    attribute is kept and only its ``.data`` is replaced.

    Args:
        params_dict: Dictionary mapping parameter names to their estimated values.
    """
    for name, estimate in params_dict.items():
        try:
            setattr(self, name, estimate)
        except TypeError:
            # Assignment was rejected: overwrite the stored tensor's data instead.
            existing = getattr(self, name)
            existing.data = estimate
@property
def event_shape(self) -> tuple[int, ...]:
    """Event shape stored on the leaf.

    Returns:
        The value of ``_event_shape``.

    Raises:
        RuntimeError: If ``_event_shape`` is None.
    """
    shape = self._event_shape
    if shape is not None:
        return shape
    raise RuntimeError(f"{self.__class__.__name__} has not set _event_shape in __init__")
@property
def feature_to_scope(self) -> np.ndarray[Scope]:
    """Map every (feature, repetition) position to a single-variable scope.

    Returns:
        Object array of shape (features, repetitions); entry ``[i, j]`` is a
        ``Scope`` holding only the i-th query variable of this leaf.
    """
    num_features = self.out_shape.features
    num_repetitions = self.out_shape.repetitions
    grid = np.empty((num_features, num_repetitions), dtype=Scope)
    for feature_idx, repetition_idx in np.ndindex(num_features, num_repetitions):
        grid[feature_idx, repetition_idx] = Scope([self.scope.query[feature_idx]])
    return grid
@property
def device(self) -> torch.device:
    """Return device of first parameter or, if there is none, of the first buffer.

    Returns:
        Device of the module.

    Raises:
        RuntimeError: If the module has neither parameters nor buffers, so no
            device can be determined.
    """
    # Parameters take precedence; buffers are only consulted for parameter-free leaves.
    for tensor in self.parameters():
        return tensor.device
    for tensor in self.buffers():
        return tensor.device
    # Previously a bare StopIteration escaped here, which is hard to diagnose.
    raise RuntimeError(
        f"Cannot determine device of {self.__class__.__name__}: module has no parameters or buffers."
    )
def _broadcast_to_event_shape(self, param_est: Tensor) -> Tensor:
    """Tile a per-feature estimate across channels (and repetitions) of the event shape.

    Args:
        param_est: Parameter estimate whose leading axis is the feature axis.

    Returns:
        ``param_est`` unchanged when its leading dims already equal the event shape
        (or the event shape is neither 2D nor 3D); otherwise a repeated copy with
        channel / repetition axes inserted after the feature axis.
    """
    target_shape = tuple(self.event_shape)
    rank = len(target_shape)

    # Already aligned (extra trailing dims allowed): avoid inserting axes twice
    # when this helper is applied repeatedly.
    if tuple(param_est.shape[:rank]) == target_shape:
        return param_est

    if rank == 2:
        with_channel = rearrange(param_est, "f ... -> f 1 ...")
        tail_ones = [1] * (with_channel.dim() - 2)
        return with_channel.repeat(1, self.out_shape.channels, *tail_ones)

    if rank == 3:
        with_axes = rearrange(param_est, "f ... -> f 1 1 ...")
        tail_ones = [1] * (with_axes.dim() - 3)
        return with_axes.repeat(
            1,
            self.out_shape.channels,
            self.out_shape.repetitions,
            *tail_ones,
        )

    return param_est
def _expectation_maximization_step(
    self,
    data: torch.Tensor,
    bias_correction: bool = True,
    *,
    cache: Cache,
) -> None:
    """Run one EM update for this leaf.

    The gradient stored on the cached log-likelihood tensor is used as the
    (unnormalised) expectation; it is normalised over the batch axis in place and
    then passed as weights to ``maximum_likelihood_estimation``.

    Args:
        data: Input data tensor.
        bias_correction: Whether to apply bias correction for leaf statistics.
        cache: Cache from a preceding forward pass; ``cache["log_likelihood"]``
            must hold this module's log-likelihood tensor.

    Raises:
        MissingCacheError: If no log-likelihood is cached for this module.
        RuntimeError: If the cached tensor carries no gradient.
    """
    # Leaves without a single trainable parameter are fixed: EM is a no-op for them.
    if all(not param.requires_grad for param in self.parameters()):
        return

    with torch.no_grad():
        # ----- expectation step -----
        module_lls = cache["log_likelihood"].get(self)
        if module_lls is None:
            raise MissingCacheError(
                "Module log-likelihoods not found in cache. Call log_likelihood first."
            )

        expectations = module_lls.grad
        if expectations is None:
            raise RuntimeError(
                f"Expected gradient for cached log-likelihood tensor of {self.__class__.__name__}, but found None."
            )

        # Both updates are in place, i.e. they modify the cached tensor's .grad.
        expectations.add_(1e-12)  # numerical stability
        expectations.div_(expectations.sum(0, keepdim=True))  # normalise over the batch

        # ----- maximization step -----
        # Parameters are overwritten explicitly by the weighted MLE, so parameter
        # gradients do not need to be zeroed afterwards.
        self.maximum_likelihood_estimation(
            data,
            weights=expectations,
            bias_correction=bias_correction,
        )
[docs]
@cached
def log_likelihood(
    self,
    data: Tensor | IntervalEvidence,
    cache: Cache | None = None,
) -> Tensor:
    """Compute log-likelihoods, marginalizing over NaN values.

    NaN entries inside the leaf's scope are temporarily replaced by a value from
    the distribution's support so ``log_prob`` can be evaluated; the corresponding
    log-probabilities are then set to 0 (i.e. probability 1, marginalized out).

    Args:
        data: Input data tensor of shape (batch, num_features), or IntervalEvidence
            for range queries.
        cache: Optional cache dictionary.

    Returns:
        Log-likelihood tensor as returned by the distribution's ``log_prob`` for
        data of shape (batch, features, 1, 1).

    Raises:
        ValueError: If ``data`` is a tensor that is not 2-dimensional.
        RuntimeError: If the event shape's feature count differs from the scope size.
    """
    # Dispatch on IntervalEvidence
    if isinstance(data, IntervalEvidence):
        return self._log_likelihood_interval(data.low, data.high, cache)

    if data.dim() != 2:
        raise ValueError(f"Data must be 2-dimensional (batch, num_features), got shape {data.shape}.")

    # Work on a private copy of the scope columns; the caller's tensor is never modified.
    data_q = data[:, self.scope.query].to(device=self.device).clone()

    if self.event_shape[0] != len(self.scope.query):
        raise RuntimeError(
            f"event_shape mismatch for {self.__class__.__name__}: event_shape={self.event_shape}, scope_len={len(self.scope.query)}"
        )

    # ----- marginalization -----
    marg_mask = torch.isnan(data_q)
    has_marginalizations = marg_mask.any()

    # Replace NaNs by a value inside the support so that log_prob does not fail
    # (e.g. under argument validation); the result at these positions is discarded below.
    if has_marginalizations:
        fill_values = self._supported_value_for_imputation(data_q)
        data_q[marg_mask] = fill_values[marg_mask]

    # ----- log probabilities -----
    # Add broadcast dims for channel and repetition axes.
    data_q = rearrange(data_q, "b f -> b f 1 1")

    if self.is_conditional:
        # Parameters are produced from the evidence columns of the same rows.
        data_e = data[:, self.scope.evidence].to(device=self.device)
        dist = self.conditional_distribution(data_e)
    else:
        dist = self.distribution()

    log_prob = dist.log_prob(data_q)

    if has_marginalizations:
        # Expand the [batch, features] mask to log_prob's shape and zero out
        # (log 1 = 0) every marginalized entry.
        marg_mask_for_log_prob = rearrange(marg_mask, "b f -> b f 1 1")
        marg_mask_for_log_prob = torch.broadcast_to(marg_mask_for_log_prob, log_prob.shape)
        log_prob[marg_mask_for_log_prob] = 0.0

    # NOTE: data_q is a local clone that is not used past this point, so the
    # imputed values do not need to be turned back into NaNs.
    return log_prob
def _log_likelihood_interval(
    self,
    low: Tensor,
    high: Tensor,
    cache: Cache | None = None,
) -> Tensor:
    """Compute the log-probability of interval evidence via the distribution's CDF.

    The probability is evaluated as ``CDF(high) - CDF(low)`` using the leaf's
    unconditional ``distribution()``. Note that for distributions with point
    masses this difference does not include the mass located exactly at ``low``.
    Subclasses can override this for custom implementations.

    Args:
        low: Lower bounds of shape (batch, features). NaN = no lower bound (-inf).
        high: Upper bounds of shape (batch, features). NaN = no upper bound (+inf).
        cache: Optional cache dictionary (not used by this implementation).

    Returns:
        Log-likelihood tensor; probabilities are clamped to at least 1e-40
        before taking the logarithm.

    Raises:
        NotImplementedError: If the distribution doesn't support CDF.
    """
    # Get scope-filtered bounds
    low_scoped = low[:, self.scope.query].to(device=self.device)
    high_scoped = high[:, self.scope.query].to(device=self.device)

    # Append singleton axes so the bounds broadcast against the distribution's
    # remaining (channel, repetition) dimensions.
    low_expanded = rearrange(low_scoped, "b f -> b f 1 1")
    high_expanded = rearrange(high_scoped, "b f -> b f 1 1")

    # Get the distribution
    dist = self.distribution()

    # Check if distribution has cdf method
    if not hasattr(dist, "cdf"):
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support interval inference. "
            f"The underlying distribution {dist.__class__.__name__} has no cdf() method."
        )

    # Handle NaN bounds as -inf/+inf
    low_for_cdf = torch.where(
        torch.isnan(low_expanded),
        torch.full_like(low_expanded, float("-inf")),
        low_expanded,
    )
    high_for_cdf = torch.where(
        torch.isnan(high_expanded),
        torch.full_like(high_expanded, float("inf")),
        high_expanded,
    )

    # Compute CDF values
    try:
        cdf_high = dist.cdf(high_for_cdf)
        cdf_low = dist.cdf(low_for_cdf)
    except NotImplementedError as err:
        # Keep the original exception as the cause so the traceback shows where
        # the missing cdf() implementation was hit.
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support interval inference. "
            f"The underlying distribution {dist.__class__.__name__} does not implement cdf()."
        ) from err

    # Clamp away from zero so the logarithm stays finite.
    prob = torch.clamp(cdf_high - cdf_low, min=1e-40)  # Numerical stability
    return torch.log(prob)
def _prepare_mle_data(
    self,
    data: Tensor,
    weights: Tensor | None = None,
    nan_strategy: str | Callable | None = None,
) -> tuple[Tensor, Tensor]:
    """Select the leaf's columns, default the weights, and apply the NaN strategy.

    Args:
        data: Input data tensor.
        weights: Optional sample weights. When omitted, a tensor of ones with shape
            (batch, features, channels, repetitions) is created on the leaf's device.
        nan_strategy: Passed through to ``apply_nan_strategy``.

    Returns:
        The pair returned by ``apply_nan_strategy``: scope-filtered data and weights.
    """
    # Keep only the columns belonging to this leaf's query variables.
    scoped_data = data[:, self.scope.query]

    if weights is None:
        shape = self.out_shape
        weights = torch.ones(
            scoped_data.shape[0],
            shape.features,
            shape.channels,
            shape.repetitions,
            device=self.device,
        )

    # NaN handling is delegated entirely to the shared helper.
    filtered_data, normalized_weights = apply_nan_strategy(
        nan_strategy, scoped_data, self.device, weights
    )
    return filtered_data, normalized_weights
[docs]
def maximum_likelihood_estimation(
    self,
    data: Tensor,
    weights: Optional[Tensor] = None,
    bias_correction: bool = True,
    nan_strategy: str | Callable | None = None,
) -> None:
    """Fit the leaf's parameters by (weighted) maximum likelihood.

    Data preparation is done by ``_prepare_mle_data``; the distribution-specific
    update is delegated to ``_mle_update_statistics``.

    Args:
        data: Input data tensor.
        weights: Optional sample weights.
        bias_correction: Apply bias correction.
        nan_strategy: Handle NaN ('ignore', callable, or None).

    Raises:
        RuntimeError: If the leaf is conditional (has a ``parameter_fn``).
    """
    if self.is_conditional:
        raise RuntimeError(f"MLE not supported for conditional leaf {self.__class__.__name__}.")

    prepared = self._prepare_mle_data(
        data=data,
        weights=weights,
        nan_strategy=nan_strategy,
    )
    data_prepared, weights_prepared = prepared

    self._mle_update_statistics(data_prepared, weights_prepared, bias_correction)
def _sample(
self,
data: Tensor,
sampling_ctx: SamplingContext,
cache: Cache,
) -> Tensor:
"""Fill NaN entries of ``data`` that belong to this leaf with samples or modes.

Only entries that are NaN, lie in this leaf's scope columns, and are selected by
the sampling context's mask are written. Depending on ``sampling_ctx.is_mpe`` the
values come from the distribution's ``mode`` or from ``rsample``/``sample``.
The selected repetition and channel are picked via the sampling context's
``repetition_index`` and the sliced channel index.

Args:
data: 2D tensor (rows, columns); modified in place and returned.
sampling_ctx: Sampling context providing routing indices and the
``is_mpe`` / ``is_differentiable`` / ``return_leaf_params`` flags.
cache: Cache object; not referenced in this method body.

Returns:
The same ``data`` tensor, with the selected NaN entries replaced.

Raises:
UnsupportedOperationError: In MPE mode, if the conditional distribution has no
``mode`` attribute.
ValueError: If MPE samples are not 4-dimensional, or the number of drawn
samples does not match the number of active rows.
RuntimeError: If a position selected for writing already holds a finite value.
"""
# Check the sampling context against this leaf's output shape and the data width.
sampling_ctx.validate_sampling_context(
num_samples=data.shape[0],
num_features=self.out_shape.features,
num_channels=self.out_shape.channels,
num_repetitions=self.out_shape.repetitions,
allowed_feature_widths=(1, self.out_shape.features, data.shape[1]),
)
scope_cols = self._resolve_scope_columns(num_features=data.shape[1])
out_of_scope = [idx for idx in range(data.shape[1]) if idx not in scope_cols]
# Candidate positions: NaN entries, restricted to this leaf's scope columns.
marg_mask = torch.isnan(data)
marg_mask[:, out_of_scope] = False
# Mask that tells us which feature at which sample is relevant and should be sampled
# (same tensor object as marg_mask; the in-place update below affects both names)
samples_mask = marg_mask
c_idx, ctx_mask = self._slice_sampling_context(
sampling_ctx=sampling_ctx, num_features=data.shape[1], scope_cols=scope_cols
)
samples_mask[:, scope_cols] &= ctx_mask
# Count number of samples to draw
instance_mask = samples_mask.sum(1) > 0
n_samples = instance_mask.sum() # count number of rows which have at least one true value
# Routing can legitimately send zero rows to a branch. In that case, there is nothing to
# sample for this leaf and we should return without touching the data tensor.
if int(n_samples.item()) == 0:
return data
# Evidence columns of the active rows; only gathered for conditional leaves.
evidence = None
if self.is_conditional:
evidence = data[instance_mask][:, self.scope.evidence]
if sampling_ctx.is_mpe:
if self.is_conditional:
dist = self.conditional_distribution(
evidence=evidence,
with_differentiable_sampling=sampling_ctx.is_differentiable,
)
if not hasattr(dist, "mode"):
raise UnsupportedOperationError(
f"MPE sampling requires a distribution with a 'mode' attribute, but "
f"{dist.__class__.__name__} does not provide one."
)
samples = dist.mode
# Pad lower-rank modes with trailing singleton axes up to 4 dimensions.
if samples.dim() == 2:
samples = rearrange(samples, "b f -> b f 1 1")
elif samples.dim() == 3:
samples = rearrange(samples, "b f ci -> b f ci 1")
else:
# Get mode of distribution as MPE
dist = self.distribution(with_differentiable_sampling=sampling_ctx.is_differentiable)
samples = rearrange(dist.mode, "f ci r -> 1 f ci r")
if samples.ndim == 4:
if not self.is_conditional:
samples = repeat(samples, "1 f ci r -> n f ci r", n=int(n_samples.item())).detach()
# repetition_index shape: (n_samples,)
r_idx = sampling_ctx.repetition_index[instance_mask]
num_features = samples.shape[1]
num_channels = samples.shape[2]
r_idx = repeat_repetition_index(
repetition_index=r_idx,
pattern="b r -> b f c r",
f=num_features,
c=num_channels,
)
samples = index_tensor(
samples, index=r_idx, dim=-1, is_differentiable=sampling_ctx.is_differentiable
)
# NOTE(review): if this elif/else pairs with the `samples.ndim == 4` check above,
# the two conditions are complementary and the final else can never run --
# confirm against the original indentation.
elif samples.ndim != 4:
raise ValueError("The samples are not 4-dimensional. This should not happen.")
else:
if not self.is_conditional:
samples = repeat(samples, "1 f ci -> n f ci", n=int(n_samples.item())).detach()
else:
if self.is_conditional:
# Get evidence (same expression as the evidence gathered before the MPE check)
evidence = data[instance_mask][:, self.scope.evidence]
dist = self.conditional_distribution(
evidence=evidence,
with_differentiable_sampling=sampling_ctx.is_differentiable,
)
# Distribution parameters already contain batch dim, therefore sample shape is (1,)
# Prefer rsample when the distribution advertises it via has_rsample.
if getattr(dist, "has_rsample", False):
samples = dist.rsample((1,)).squeeze(0)
else:
samples = dist.sample((1,)).squeeze(0)
else:
# Sample n_samples from distribution
dist = self.distribution(with_differentiable_sampling=sampling_ctx.is_differentiable)
if getattr(dist, "has_rsample", False):
samples = dist.rsample((n_samples,))
else:
samples = dist.sample((n_samples,))
# repetition_index shape: (n_samples,)
r_idx = sampling_ctx.repetition_index[instance_mask]
num_features = samples.shape[1]
num_channels = samples.shape[2]
r_idx = repeat_repetition_index(
repetition_index=r_idx,
pattern="b r -> b f c r",
f=num_features,
c=num_channels,
)
# Index into samples with r_idx to select the correct repetition for each sample
samples = index_tensor(
samples, index=r_idx, dim=-1, is_differentiable=sampling_ctx.is_differentiable
)
# One sample row is expected per active row of the sampling context.
if samples.shape[0] != sampling_ctx.channel_index[instance_mask].shape[0]:
raise ValueError(
f"Sample shape mismatch: got {samples.shape[0]}, expected {sampling_ctx.channel_index[instance_mask].shape[0]}"
)
# Keep the un-reshaped active channel index for the leaf-parameter record below.
c_idx_active = c_idx[instance_mask]
c_idx = c_idx_active
if not sampling_ctx.is_differentiable:
c_idx = rearrange(c_idx, "b f -> b f 1")
# Select the routed channel along dim 2.
samples = index_tensor(
tensor=samples, index=c_idx, dim=2, is_differentiable=sampling_ctx.is_differentiable
)
if sampling_ctx.return_leaf_params:
self._collect_leaf_param_record(
sampling_ctx=sampling_ctx,
instance_mask=instance_mask,
scope_cols=scope_cols,
samples_mask=samples_mask,
channel_index_active=c_idx_active,
evidence=evidence,
)
# Ensure, that no data is overwritten
if data[samples_mask].isfinite().any():
raise RuntimeError("Data already contains values at the specified mask. This should not happen.")
# Update data inplace - place samples at correct scope positions (vectorized)
# samples[:, feat_idx] should go to data[:, scope.query[feat_idx]]
# Only write where the mask is True for that specific position
# Get row indices for instances that need sampling
row_indices = instance_mask.nonzero(as_tuple=True)[0] # (n_instances,)
# Create scope indices tensor
scope_idx = torch.tensor(scope_cols, dtype=torch.long, device=data.device)
# Expand to create all (row, col) index pairs
# rows: (n_instances, out_features) - row index repeated for each feature
# cols: (n_instances, out_features) - scope indices repeated for each instance
num_scope_features = len(scope_idx)
num_instances = int(n_samples.item())
rows = repeat(row_indices, "n -> n s", s=num_scope_features)
cols = repeat(scope_idx, "s -> n s", n=num_instances)
# Get mask subset for scope positions only
mask_subset = samples_mask[instance_mask][:, scope_cols] # (n_instances, out_features)
# Apply mask and flatten for single vectorized assignment
data[rows[mask_subset], cols[mask_subset]] = samples[mask_subset].to(data.dtype)
return data
def _normalize_param_tensor_for_routing(self, param: Tensor, n_active: int) -> Tensor:
    """Bring a parameter tensor into the layout (N_active, F, C, R, *tail).

    Tensors that already start with (n_active, features) are taken as-is; tensors
    starting with the feature axis get a new leading batch axis of size
    ``n_active``. Singleton channel / repetition axes are then repeated to the
    leaf's full channel / repetition counts.

    Args:
        param: Parameter tensor of rank >= 3.
        n_active: Number of active rows the parameters are routed for.

    Returns:
        Tensor of shape (n_active, features, channels, repetitions, *tail).

    Raises:
        ShapeError: If the rank is below 3 or any leading dimension is incompatible.
    """
    if param.dim() < 3:
        raise ShapeError(
            f"Leaf parameter tensors must have rank >= 3 (features, channels, repetitions), got rank {param.dim()}."
        )

    num_features = self.out_shape.features
    num_channels = self.out_shape.channels
    num_repetitions = self.out_shape.repetitions

    already_batched = (
        param.dim() >= 4 and param.shape[0] == n_active and param.shape[1] == num_features
    )
    if already_batched:
        routed = param
    elif param.shape[0] != num_features:
        raise ShapeError(
            "Leaf parameter tensor has incompatible leading dimensions: "
            f"got {tuple(param.shape)}, expected ({num_features}, C, R, ...) or "
            f"({n_active}, {num_features}, C, R, ...)."
        )
    else:
        # Shared (unbatched) parameters: replicate them for every active row.
        batch_pattern = "f c r -> n f c r" if param.dim() == 3 else "f c r ... -> n f c r ..."
        routed = repeat(param, batch_pattern, n=n_active)

    if routed.dim() < 4:
        raise ShapeError(
            f"Leaf parameter tensor must include channels and repetitions after normalization, got {tuple(routed.shape)}."
        )
    if routed.shape[2] not in (1, num_channels):
        raise ShapeError(
            f"Leaf parameter tensor channel dimension mismatch: got {routed.shape[2]}, expected 1 or {num_channels}."
        )
    if routed.shape[3] not in (1, num_repetitions):
        raise ShapeError(
            f"Leaf parameter tensor repetition dimension mismatch: got {routed.shape[3]}, expected 1 or {num_repetitions}."
        )

    # Materialise broadcastable singleton axes at their full size.
    if routed.shape[2] == 1 and num_channels > 1:
        channel_pattern = "n f 1 r -> n f c r" if routed.dim() == 4 else "n f 1 r ... -> n f c r ..."
        routed = repeat(routed, channel_pattern, c=num_channels)
    if routed.shape[3] == 1 and num_repetitions > 1:
        repetition_pattern = "n f c 1 -> n f c r" if routed.dim() == 4 else "n f c 1 ... -> n f c r ..."
        routed = repeat(routed, repetition_pattern, r=num_repetitions)

    return routed
def _select_param_repetition(
    self,
    *,
    params: Tensor,
    sampling_ctx: SamplingContext,
    instance_mask: Tensor,
) -> Tensor:
    """Pick the routed repetition of each row, yielding shape (N, F, C, *tail).

    In differentiable mode the context's repetition index is expected to already be
    one-hot weights of shape (N, R); otherwise integer indices of shape (N,) or
    (N, 1) are converted to one-hot weights. Selection is done with
    ``index_one_hot`` along the repetition axis.

    Args:
        params: Tensor laid out as (N, F, C, R, *tail).
        sampling_ctx: Sampling context holding ``repetition_index``.
        instance_mask: Boolean mask selecting the active rows of the context.

    Raises:
        ShapeError: If the repetition routing has an incompatible shape or, in
            differentiable mode, does not sum to one per row.
    """
    # Collapse all trailing axes into one so the selection works on a fixed rank.
    params_flat = rearrange(params, "n f c r ... -> n f c r (...)")
    tail_shape = params.shape[4:]

    if sampling_ctx.is_differentiable:
        repetition_weights = sampling_ctx.repetition_index[instance_mask]
        if repetition_weights.dim() != 2 or repetition_weights.shape[1] != params.shape[3]:
            raise ShapeError(
                "Differentiable repetition routing has incompatible shape for leaf parameters: "
                f"got {tuple(repetition_weights.shape)}, expected ({params.shape[0]}, {params.shape[3]})."
            )
        one_hot_sums = repetition_weights.sum(dim=1)
        if not torch.allclose(one_hot_sums, torch.ones_like(one_hot_sums), rtol=0.0, atol=1e-6):
            raise ShapeError(
                "Differentiable repetition routing for leaf parameters must be one-hot encoded per sample."
            )
        repetition_weights = repetition_weights.to(device=params.device, dtype=params.dtype)
    else:
        repetition_index = sampling_ctx.repetition_index[instance_mask]
        if repetition_index.dim() == 2:
            if repetition_index.shape[1] != 1:
                raise ShapeError(
                    "Non-differentiable repetition routing must be shape (batch,) or (batch, 1), "
                    f"got {tuple(repetition_index.shape)}."
                )
            repetition_index = repetition_index[:, 0]
        repetition_index = repetition_index.to(device=params.device, dtype=torch.long)
        repetition_weights = torch.nn.functional.one_hot(
            repetition_index,
            num_classes=params.shape[3],
        ).to(dtype=params.dtype)

    # Common tail: broadcast the (N, R) weights over features, channels and the flat tail.
    repetition_weights = rearrange(repetition_weights, "n r -> n 1 1 r 1")
    selected_flat = index_one_hot(params_flat, index=repetition_weights, dim=3)
    return selected_flat.reshape(params.shape[0], params.shape[1], params.shape[2], *tail_shape)
def _select_param_channel(
    self, *, params: Tensor, channel_index: Tensor, is_differentiable: bool
) -> Tensor:
    """Pick the routed channel of each (row, feature), yielding shape (N, F, *tail).

    Differentiable routing passes one-hot weights of shape (N, F or 1, C);
    non-differentiable routing passes integer indices of shape (N, F or 1) that
    are turned into one-hot weights. A feature width of 1 is repeated across all
    features.

    Args:
        params: Tensor laid out as (N, F, C, *tail).
        channel_index: Channel routing for the active rows.
        is_differentiable: Whether ``channel_index`` holds one-hot weights.

    Raises:
        ShapeError: If the routing tensor's rank or sizes are incompatible or, in
            differentiable mode, does not sum to one per (row, feature).
    """
    # Collapse all trailing axes into one so the selection works on a fixed rank.
    params_flat = rearrange(params, "n f c ... -> n f c (...)")
    tail_shape = params.shape[3:]

    if is_differentiable:
        channel_weights = channel_index
        if channel_weights.dim() != 3 or channel_weights.shape[2] != params.shape[2]:
            raise ShapeError(
                "Differentiable channel routing has incompatible shape for leaf parameters: "
                f"got {tuple(channel_weights.shape)}, expected ({params.shape[0]}, {params.shape[1]}, {params.shape[2]})."
            )
        if channel_weights.shape[1] == 1 and params.shape[1] > 1:
            channel_weights = repeat(channel_weights, "n 1 c -> n f c", f=params.shape[1])
        elif channel_weights.shape[1] != params.shape[1]:
            raise ShapeError(
                "Differentiable channel routing feature width mismatch for leaf parameters: "
                f"got {channel_weights.shape[1]}, expected 1 or {params.shape[1]}."
            )
        one_hot_sums = channel_weights.sum(dim=2)
        if not torch.allclose(one_hot_sums, torch.ones_like(one_hot_sums), rtol=0.0, atol=1e-6):
            raise ShapeError(
                "Differentiable channel routing for leaf parameters must be one-hot encoded per (sample, feature)."
            )
        channel_weights = channel_weights.to(device=params.device, dtype=params.dtype)
    else:
        if channel_index.dim() != 2:
            raise ShapeError(
                "Non-differentiable channel routing for leaf parameters expects rank-2 indices, "
                f"got rank {channel_index.dim()}."
            )
        gather_index = channel_index.to(device=params.device, dtype=torch.long)
        if gather_index.shape[1] == 1 and params.shape[1] > 1:
            gather_index = repeat(gather_index, "n 1 -> n f", f=params.shape[1])
        elif gather_index.shape[1] != params.shape[1]:
            raise ShapeError(
                "Channel routing feature width mismatch for leaf parameters: "
                f"got {gather_index.shape[1]}, expected 1 or {params.shape[1]}."
            )
        channel_weights = torch.nn.functional.one_hot(gather_index, num_classes=params.shape[2]).to(
            dtype=params.dtype,
            device=params.device,
        )

    # Common tail: add a trailing axis so the weights broadcast over the flat tail.
    channel_weights = rearrange(channel_weights, "n f c -> n f c 1")
    selected_flat = index_one_hot(params_flat, index=channel_weights, dim=2)
    return selected_flat.reshape(params.shape[0], params.shape[1], *tail_shape)
def _collect_leaf_param_record(
self,
*,
sampling_ctx: SamplingContext,
instance_mask: Tensor,
scope_cols: list[int],
samples_mask: Tensor,
channel_index_active: Tensor,
evidence: Tensor | None,
) -> None:
"""Collect routed leaf parameters into the sampling context."""
n_active = int(instance_mask.sum().item())
if n_active == 0:
return
if self.is_conditional:
if evidence is None:
raise RuntimeError("Conditional leaf parameter collection requires evidence.")
params_dict = self.parameter_fn(evidence)
else:
params_dict = self.params()
if "scale" in params_dict and "log_scale" not in params_dict:
log_scale_tensor = getattr(self, "log_scale", None)
if isinstance(log_scale_tensor, Tensor):
params_dict = dict(params_dict)
params_dict["log_scale"] = log_scale_tensor
active_mask = samples_mask[:, scope_cols].clone()
routed_params: dict[str, Tensor] = {}
for key, value in params_dict.items():
if not isinstance(value, Tensor):
raise UnsupportedOperationError(
f"Leaf parameter '{key}' must be a Tensor when return_leaf_params=True, got {type(value)}."
)
params = self._normalize_param_tensor_for_routing(value, n_active=n_active)
params = self._select_param_repetition(
params=params,
sampling_ctx=sampling_ctx,
instance_mask=instance_mask,
)
params = self._select_param_channel(
params=params,
channel_index=channel_index_active,
is_differentiable=sampling_ctx.is_differentiable,
)
full_shape = (int(instance_mask.shape[0]), len(scope_cols), *params.shape[2:])
full_params = params.new_zeros(full_shape)
full_params[instance_mask] = params
routed_params[key] = full_params
sampling_ctx.add_leaf_param_record(
LeafParamRecord(
leaf_id=id(self),
leaf_type=self.__class__.__name__,
scope_cols=tuple(scope_cols),
active_mask=active_mask,
params=routed_params,
)
)
def _resolve_scope_columns(self, num_features: int) -> list[int]:
"""Resolve the column indices in `data` that correspond to this leaf's scope.
The sampling API can operate on either:
- "Global" feature tensors where columns are indexed by RV id.
- "Scoped" feature tensors where columns correspond exactly to this module's scope ordering.
Args:
num_features: Number of feature columns in the provided data tensor.
Returns:
List of column indices into the provided data tensor that correspond to `self.scope.query`.
"""
query = list(self.scope.query)
if len(query) == 0:
return []
if all(0 <= rv < num_features for rv in query):
return query
if num_features == self.out_shape.features:
return list(range(num_features))
raise ShapeError(
f"Cannot map scope {self.scope} to data with {num_features} features; "
f"expected either all RV ids < {num_features} or a scoped tensor with {self.out_shape.features} features."
)
def _slice_sampling_context(
self, sampling_ctx: SamplingContext, num_features: int, scope_cols: list[int]
) -> tuple[Tensor, Tensor]:
"""Slice/expand sampling context tensors to align with this leaf's feature axis.
Args:
sampling_ctx: Sampling context provided to sampling.
num_features: Number of feature columns in the provided data tensor.
scope_cols: Column indices (into data) that correspond to this leaf's scope.
Returns:
Tuple of (channel_index, mask), both shaped (num_samples, len(scope_cols)).
Raises:
ShapeError: If the context tensors cannot be aligned to the leaf scope.
"""
ctx_features = sampling_ctx.mask.shape[1]
scope_size = len(scope_cols)
if ctx_features == scope_size:
return sampling_ctx.channel_index, sampling_ctx.mask
if ctx_features == 1:
if sampling_ctx.is_differentiable:
return (
repeat(sampling_ctx.channel_index, "b 1 c -> b f c", f=scope_size),
repeat(sampling_ctx.mask, "b 1 -> b f", f=scope_size),
)
return (
repeat(sampling_ctx.channel_index, "b 1 -> b f", f=scope_size),
repeat(sampling_ctx.mask, "b 1 -> b f", f=scope_size),
)
if ctx_features == num_features:
return sampling_ctx.channel_index[:, scope_cols], sampling_ctx.mask[:, scope_cols]
raise ShapeError(
"SamplingContext feature dimension mismatch: "
f"got {ctx_features}, expected {scope_size} or {num_features}."
)
[docs]
def marginalize(
    self,
    marg_rvs: list[int],
    prune: bool = True,
    cache: Cache | None = None,
) -> Optional["LeafModule"]:
    """Build a new leaf restricted to the variables that are not marginalized out.

    Args:
        marg_rvs: Variable indices to marginalize.
        prune: Unused (kept for interface consistency).
        cache: Optional cache; not read by this implementation.

    Returns:
        A new instance of the same class over the remaining scope, constructed
        from detached parameters, or None if every variable was marginalized.

    Raises:
        RuntimeError: If the leaf is conditional.
    """
    if self.is_conditional:
        raise RuntimeError(
            f"Marginalization not supported for conditional leaf {self.__class__.__name__}."
        )

    # Scope that survives marginalization.
    scope_marg = Scope([q for q in self.scope.query if q not in marg_rvs])
    remaining_rvs = scope_marg.query
    if len(remaining_rvs) == 0:
        return None

    # Positions, within the original scope, of the variables that are kept.
    kept_positions = [
        position for position, rv in enumerate(self.scope.query) if rv in remaining_rvs
    ]

    # Detach so the new leaf does not share autograd history with this one.
    kept_params = self.marginalized_params(kept_positions)
    detached_params = {name: tensor.detach() for name, tensor in kept_params.items()}

    # Instantiate the same concrete leaf class over the reduced scope.
    return self.__class__(scope=scope_marg, **detached_params)