Source code for spflow.modules.leaves.leaf

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Optional, Dict, Callable, Iterable, Union

import numpy as np
import torch
from einops import rearrange, repeat
from torch import Tensor

from spflow.exceptions import MissingCacheError, ShapeError, UnsupportedOperationError
from spflow.meta.data.interval_evidence import IntervalEvidence
from spflow.meta.data.scope import Scope
from spflow.modules.module import Module
from spflow.modules.module_shape import ModuleShape
from spflow.utils.cache import Cache, cached
from spflow.utils.leaves import apply_nan_strategy, parse_leaf_args
from spflow.utils.sampling_context import LeafParamRecord, SamplingContext
from spflow.utils.sampling_context import index_one_hot, index_tensor, repeat_repetition_index


[docs] class LeafModule(Module, ABC):
def __init__(
    self,
    scope: Scope | int | Iterable[int],
    out_channels: int = 1,
    num_repetitions: int = 1,
    params: list[Tensor | None] | None = None,
    parameter_fn: Callable[[Tensor], dict[str, Tensor]] = None,
    validate_args: bool | None = True,
):
    """Set up the state shared by all leaf distribution modules.

    Args:
        scope: Variable scope. Either a ``Scope`` instance or anything accepted by
            the ``Scope`` constructor (the original documentation lists a single
            integer or an iterable of integers).
        out_channels: Number of output channels; forwarded to ``parse_leaf_args``.
        num_repetitions: Number of repetitions; forwarded to ``parse_leaf_args``.
        params: List of parameter tensors (entries may be None); forwarded to
            ``parse_leaf_args`` to resolve the event shape.
        parameter_fn: Optional callable mapping an evidence tensor to a dictionary
            of distribution parameters. When set, the leaf is conditional.
        validate_args: Stored and later passed as ``validate_args`` to the torch
            distribution constructor.
    """
    super().__init__()

    # Resolve the event shape from the raw (not yet normalized) arguments.
    resolved_event_shape = parse_leaf_args(
        scope=scope,
        out_channels=out_channels,
        num_repetitions=num_repetitions,
        params=params,
    )

    # Anything that is not already a Scope is wrapped into one.
    scope_obj = scope if isinstance(scope, Scope) else Scope(scope)

    # Store a copy so the leaf does not share the caller's Scope object.
    self.scope = scope_obj.copy()
    self._event_shape = resolved_event_shape
    self.parameter_fn = parameter_fn
    self._validate_args = validate_args

    # Event shape is (features[, channels[, repetitions]]); absent dims count as 1.
    rank = len(resolved_event_shape)
    n_features = resolved_event_shape[0]
    n_channels = resolved_event_shape[1] if rank > 1 else 1
    n_repetitions = resolved_event_shape[2] if rank > 2 else 1

    self.in_shape = ModuleShape(n_features, 1, 1)
    self.out_shape = ModuleShape(n_features, n_channels, n_repetitions)
@property
def inputs(self) -> Module | Iterable[Module]:
    """Leaves are terminal modules: reading ``inputs`` always raises.

    Raises:
        AttributeError: Always.
    """
    message = "LeafModule does not have 'input' attribute -- this should not have been called."
    raise AttributeError(message)


@inputs.setter
def inputs(self, value):
    """Leaves are terminal modules: assigning ``inputs`` always raises.

    Raises:
        AttributeError: Always.
    """
    message = "LeafModule does not have 'input' attribute -- this should not have been called."
    raise AttributeError(message)


@property
def is_conditional(self):
    """True when a ``parameter_fn`` was supplied, i.e. parameters come from evidence."""
    has_parameter_fn = self.parameter_fn is not None
    return has_parameter_fn
def distribution(self, with_differentiable_sampling: bool = False) -> torch.distributions.Distribution:
    """Build the torch distribution for this leaf from its current parameters.

    Args:
        with_differentiable_sampling: Forwarded to the internal distribution
            factory, which picks the differentiable-sampling distribution class
            when True and the regular one otherwise.

    Returns:
        Distribution constructed from ``self.params()``.
    """
    current_params = self.params()
    return self.__make_distribution(
        current_params,
        with_differentiable_sampling=with_differentiable_sampling,
    )
@property
@abstractmethod
def _torch_distribution_class(self) -> type[torch.distributions.Distribution]:
    """Torch distribution class backing this leaf; must be provided by subclasses."""
    pass


@property
def _torch_distribution_class_with_differentiable_sampling(
    self,
) -> type[torch.distributions.Distribution]:
    """Alternative distribution class used when differentiable sampling is requested.

    The base implementation has no such alternative and always raises.

    Raises:
        NotImplementedError: Always, unless overridden by a subclass.
    """
    leaf_name = self.__class__.__name__
    raise NotImplementedError(
        f"{leaf_name} does not implement an alternative distribution with differentiable sampling. "
        f"Override _torch_distribution_class_with_differentiable_sampling or set with_differentiable_sampling=False when calling distribution()."
    )


def __make_distribution(
    self, params: Dict[str, Tensor], with_differentiable_sampling: bool = False
) -> torch.distributions.Distribution:
    """Instantiate the leaf's torch distribution from a parameter dictionary.

    Args:
        params: Keyword arguments for the distribution constructor.
        with_differentiable_sampling: Select the differentiable-sampling
            distribution class instead of the regular one.

    Returns:
        Distribution built with ``validate_args=self._validate_args`` and ``params``.
    """
    # Only the selected property is evaluated, so the "not implemented" default
    # for the differentiable class is never touched on the regular path.
    distribution_cls = (
        self._torch_distribution_class_with_differentiable_sampling
        if with_differentiable_sampling
        else self._torch_distribution_class
    )
    return distribution_cls(validate_args=self._validate_args, **params)  # type: ignore[call-arg]
def conditional_distribution(
    self, evidence: Tensor, with_differentiable_sampling: bool = False
) -> torch.distributions.Distribution:
    """Build a distribution whose parameters are computed from evidence.

    Args:
        evidence: Evidence tensor passed to ``parameter_fn``.
        with_differentiable_sampling: Forwarded to the internal distribution
            factory, which picks the differentiable-sampling distribution class
            when True.

    Returns:
        Distribution constructed from the parameters returned by ``parameter_fn``.

    Raises:
        ValueError: If ``evidence`` is None.
        RuntimeError: If the leaf has no ``parameter_fn`` (it is not conditional).
    """
    if evidence is None:
        raise ValueError("Evidence tensor must be provided for conditional distribution.")

    # Fail with a descriptive message instead of "'NoneType' object is not callable".
    if self.parameter_fn is None:
        raise RuntimeError(
            f"{self.__class__.__name__} is not conditional: no parameter_fn was provided."
        )

    conditional_params = self.parameter_fn(evidence)
    return self.__make_distribution(
        conditional_params, with_differentiable_sampling=with_differentiable_sampling
    )
@property
@abstractmethod
def _supported_value(self) -> Union[float, Tensor]:
    """A value inside the distribution's support, used to fill in NaNs.

    Subclasses may return a Python scalar or a tensor. A tensor must either hold
    a single element or have a leading feature dimension equal to
    ``len(self.scope.query)``.
    """
    pass


def _supported_value_for_imputation(self, scoped_data: Tensor) -> Tensor:
    """Expand ``_supported_value`` to the shape of ``scoped_data``.

    Normalizes whatever the leaf returns from ``_supported_value`` (scalar,
    single-element tensor, per-feature vector, or a higher-rank tensor with a
    leading feature dimension) into a (batch, features) tensor.

    Args:
        scoped_data: 2D tensor of shape (batch, features) for the leaf scope.

    Returns:
        Tensor of supported values with the shape of ``scoped_data``.

    Raises:
        ShapeError: If ``scoped_data`` is not 2D or the supported value cannot be
            aligned to the feature dimension.
        TypeError: If ``_supported_value`` is neither a number nor a Tensor.
    """
    if scoped_data.dim() != 2:
        raise ShapeError(
            f"Expected scoped_data to be 2D (batch, features), got shape {tuple(scoped_data.shape)}."
        )

    raw_value = self._supported_value

    # Plain Python numbers: one constant for every entry.
    if isinstance(raw_value, (float, int)):
        return torch.full_like(scoped_data, float(raw_value))

    if not isinstance(raw_value, Tensor):
        raise TypeError(
            f"{self.__class__.__name__}._supported_value must be a float or Tensor, got {type(raw_value)}."
        )

    value_tensor = raw_value.to(device=scoped_data.device, dtype=scoped_data.dtype)
    n_rows, n_cols = scoped_data.shape

    # Single-element tensor: treat like a scalar.
    if value_tensor.numel() == 1:
        return repeat(value_tensor.reshape(()), "-> b f", b=n_rows, f=n_cols)

    # Per-feature vector: tile over the batch.
    if value_tensor.dim() == 1:
        if value_tensor.shape[0] != n_cols:
            raise ShapeError(
                f"{self.__class__.__name__}._supported_value has shape {tuple(value_tensor.shape)}; "
                f"expected ({n_cols},)."
            )
        return repeat(value_tensor, "f -> b f", b=n_rows)

    # Higher rank: the leading dimension must be the feature dimension.
    if value_tensor.shape[0] != n_cols:
        raise ShapeError(
            f"{self.__class__.__name__}._supported_value has shape {tuple(value_tensor.shape)}; "
            f"expected first dimension to match number of features ({n_cols})."
        )

    # Drop every trailing dimension by taking its first entry.
    per_feature = value_tensor
    for _ in range(value_tensor.dim() - 1):
        per_feature = per_feature.select(dim=1, index=0)

    if per_feature.shape != (n_cols,):
        raise ShapeError(
            f"{self.__class__.__name__}._supported_value could not be reduced to per-feature values; "
            f"got shape {tuple(per_feature.shape)}."
        )

    return repeat(per_feature, "f -> b f", b=n_rows)
@abstractmethod
def params(self) -> Dict[str, Tensor]:
    """Return the distribution parameters as a name-to-tensor mapping.

    Abstract hook: concrete leaves override this. The returned dictionary is
    unpacked as keyword arguments into the torch distribution constructor.
    """
    return None
def mode(self, is_differentiable: bool = False) -> Tensor:
    """Return the mode of the leaf distribution.

    Args:
        is_differentiable: When True, take the mode from the distribution built
            with ``with_differentiable_sampling=True`` (only available if the leaf
            provides such a distribution). Defaults to the regular distribution.

    Returns:
        Mode of the selected distribution.
    """
    # Forward the flag; it used to be accepted but silently ignored.
    dist = self.distribution(with_differentiable_sampling=is_differentiable)
    return dist.mode
def marginalized_params(self, indices: list[int]) -> dict[str, Tensor]:
    """Select the given entries along the first axis of every parameter.

    Args:
        indices: Indices to keep along the leading dimension of each parameter.

    Returns:
        Dictionary with the same keys as ``self.params()`` and indexed tensors.
    """
    selected: dict[str, Tensor] = {}
    for name, tensor in self.params().items():
        selected[name] = tensor[indices]
    return selected
def _mle_update_statistics(self, data: Tensor, weights: Tensor, bias_correction: bool) -> None:
    """Compute MLE parameter estimates and write them into the module.

    Args:
        data: Input data tensor of shape (batch, features).
        weights: Weight tensor for each data point.
        bias_correction: Whether to apply bias correction to variance estimate.
    """
    # Append singleton axes so the data lines up with 4D weights/parameters.
    data = rearrange(data, "b f -> b f 1 1")
    estimates = self._compute_parameter_estimates(data, weights, bias_correction)
    self._set_mle_parameters(estimates)


@abstractmethod
def _compute_parameter_estimates(
    self, data: Tensor, weights: Tensor, bias_correction: bool
) -> Dict[str, Tensor]:
    """Compute raw MLE parameter estimates (distribution-specific hook).

    Args:
        data: Scope-filtered data.
        weights: Normalized weights.
        bias_correction: Apply bias correction.

    Returns:
        Dictionary mapping parameter names to their estimates.
    """
    pass


def _set_mle_parameters(self, params_dict: Dict[str, Tensor]) -> None:
    """Assign MLE-estimated parameters to the module.

    Works both for property-based parameters with custom setters and for
    attributes that reject plain assignment with a ``TypeError``; in the latter
    case the existing attribute's ``.data`` is overwritten instead.

    Args:
        params_dict: Dictionary mapping parameter names to their estimated values.
    """
    for param_name, param_tensor in params_dict.items():
        try:
            # Property setters (e.g. 'scale') accept the tensor directly.
            setattr(self, param_name, param_tensor)
        except TypeError:
            # Direct parameter (e.g. 'loc'): update its .data in place.
            getattr(self, param_name).data = param_tensor


@property
def event_shape(self) -> tuple[int, ...]:
    """Return the event shape resolved in ``__init__``.

    Raises:
        RuntimeError: If ``_event_shape`` is None.
    """
    if self._event_shape is None:
        raise RuntimeError(f"{self.__class__.__name__} has not set _event_shape in __init__")
    return self._event_shape


@property
def feature_to_scope(self) -> np.ndarray[Scope]:
    """Return the single-variable scope of every (feature, repetition) cell.

    Returns:
        Object array of shape (features, repetitions); entry ``[i, j]`` is a
        ``Scope`` holding only the i-th query variable.
    """
    scopes = np.empty((self.out_shape.features, self.out_shape.repetitions), dtype=Scope)
    for i in range(self.out_shape.features):
        for j in range(self.out_shape.repetitions):
            scopes[i, j] = Scope([self.scope.query[i]])
    return scopes


@property
def device(self) -> torch.device:
    """Return the device of the first parameter, else of the first buffer.

    Returns:
        Device of the module. Falls back to CPU when the module owns neither
        parameters nor buffers, instead of leaking ``StopIteration``.
    """
    first_tensor = next(iter(self.parameters()), None)
    if first_tensor is None:
        first_tensor = next(iter(self.buffers()), None)
    if first_tensor is None:
        return torch.device("cpu")
    return first_tensor.device


def _broadcast_to_event_shape(self, param_est: Tensor) -> Tensor:
    """Broadcast a per-feature parameter estimate to the event shape.

    Args:
        param_est: Parameter estimate tensor whose first axis is the feature axis.

    Returns:
        Estimate with channel (and repetition) axes inserted after the feature
        axis and repeated to match ``event_shape``; returned unchanged if its
        leading dimensions already equal ``event_shape``.
    """
    target_shape = tuple(self.event_shape)

    # Already matching (possibly with extra trailing dims): nothing to do. This
    # also keeps chained calls from inserting the axes a second time.
    if tuple(param_est.shape[: len(target_shape)]) == target_shape:
        return param_est

    if len(target_shape) == 2:
        param_est = rearrange(param_est, "f ... -> f 1 ...")
        param_est = param_est.repeat(1, self.out_shape.channels, *([1] * (param_est.dim() - 2)))
    elif len(target_shape) == 3:
        param_est = rearrange(param_est, "f ... -> f 1 1 ...")
        param_est = param_est.repeat(
            1,
            self.out_shape.channels,
            self.out_shape.repetitions,
            *([1] * (param_est.dim() - 3)),
        )

    return param_est


def _expectation_maximization_step(
    self,
    data: torch.Tensor,
    bias_correction: bool = True,
    *,
    cache: Cache,
) -> None:
    """Perform a single EM step.

    Args:
        data: Input data tensor.
        bias_correction: Whether to apply bias correction for leaf statistics.
        cache: Cache from a preceding forward pass; must hold this module's
            log-likelihood tensor with a populated ``.grad``.

    Raises:
        MissingCacheError: If no log-likelihood is cached for this module.
        RuntimeError: If the cached log-likelihood tensor has no gradient.
    """
    # Fixed-parameter leaves (no learnable parameters) are explicit EM no-ops.
    if not any(param.requires_grad for param in self.parameters()):
        return

    with torch.no_grad():
        # ----- expectation step -----
        # The gradient stored on the cached log-likelihood tensor serves as the
        # per-sample expectation.
        module_lls = cache["log_likelihood"].get(self)
        if module_lls is None:
            raise MissingCacheError(
                "Module log-likelihoods not found in cache. Call log_likelihood first."
            )
        expectations = module_lls.grad
        if expectations is None:
            raise RuntimeError(
                f"Expected gradient for cached log-likelihood tensor of {self.__class__.__name__}, but found None."
            )
        # NOTE: both updates are in place, i.e. they modify ``module_lls.grad``.
        expectations += 1e-12  # numerical stability
        expectations /= expectations.sum(0, keepdim=True)  # normalize over the batch axis

        # ----- maximization step -----
        # Update parameters through maximum weighted likelihood estimation.
        self.maximum_likelihood_estimation(
            data,
            weights=expectations,
            bias_correction=bias_correction,
        )

    # NOTE: since we explicitly override parameters in 'maximum_likelihood_estimation',
    # we do not need to zero/None parameter gradients
@cached
def log_likelihood(
    self,
    data: Tensor | IntervalEvidence,
    cache: Cache | None = None,
) -> Tensor:
    """Compute log-likelihoods, marginalizing over NaN values.

    Args:
        data: Input data tensor of shape (batch, num_features), or
            IntervalEvidence for range queries.
        cache: Optional cache dictionary.

    Returns:
        Log-likelihood tensor; entries belonging to NaN (marginalized) inputs
        are 0.

    Raises:
        ValueError: If ``data`` is a tensor that is not 2-dimensional.
        RuntimeError: If the event shape disagrees with the scope length.
    """
    # Dispatch on IntervalEvidence
    if isinstance(data, IntervalEvidence):
        return self._log_likelihood_interval(data.low, data.high, cache)

    if data.dim() != 2:
        raise ValueError(f"Data must be 2-dimensional (batch, num_features), got shape {data.shape}.")

    # Work on a private copy of the scope columns so the caller's tensor is
    # never modified by the NaN imputation below.
    data_q = data[:, self.scope.query].to(device=self.device).clone()

    if self.event_shape[0] != len(self.scope.query):
        raise RuntimeError(
            f"event_shape mismatch for {self.__class__.__name__}: event_shape={self.event_shape}, scope_len={len(self.scope.query)}"
        )

    # ----- marginalization -----
    marg_mask = torch.isnan(data_q)
    has_marginalizations = marg_mask.any()

    # Replace NaNs by a value inside the distribution's support so that
    # log_prob (and its argument validation) does not fail on them.
    if has_marginalizations:
        fill_values = self._supported_value_for_imputation(data_q)
        data_q[marg_mask] = fill_values[marg_mask]

    # ----- log probabilities -----
    # Add broadcast dims for the channel and repetition axes.
    data_q = rearrange(data_q, "b f -> b f 1 1")

    if self.is_conditional:
        # Parameters are computed from the evidence columns.
        data_e = data[:, self.scope.evidence].to(device=self.device)
        dist = self.conditional_distribution(data_e)
    else:
        dist = self.distribution()

    log_prob = dist.log_prob(data_q)

    # Marginalized entries contribute log(1) = 0. Done out of place so tensors
    # saved by autograd stay untouched.
    if has_marginalizations:
        marg_mask_for_log_prob = rearrange(marg_mask, "b f -> b f 1 1")
        log_prob = log_prob.masked_fill(marg_mask_for_log_prob, 0.0)

    # NOTE: ``data_q`` is a local copy, so there is nothing to restore. Writing
    # NaNs back into it in place would corrupt the tensor that ``log_prob`` may
    # have saved for the backward pass.
    return log_prob
def _log_likelihood_interval(
    self,
    low: Tensor,
    high: Tensor,
    cache: Cache | None = None,
) -> Tensor:
    """Compute log P(low <= X <= high) for interval evidence.

    Uses the distribution's CDF: P(low <= X <= high) = CDF(high) - CDF(low).
    Subclasses can override this for custom implementations.

    Args:
        low: Lower bounds of shape (batch, features). NaN = no lower bound (-inf).
        high: Upper bounds of shape (batch, features). NaN = no upper bound (+inf).
        cache: Optional cache dictionary (unused by this implementation).

    Returns:
        Log-likelihood tensor.

    Raises:
        NotImplementedError: If the distribution doesn't support CDF.
    """
    # Get scope-filtered bounds
    low_scoped = low[:, self.scope.query].to(device=self.device)
    high_scoped = high[:, self.scope.query].to(device=self.device)

    # Add broadcast dims for the channel and repetition axes.
    low_expanded = rearrange(low_scoped, "b f -> b f 1 1")
    high_expanded = rearrange(high_scoped, "b f -> b f 1 1")

    # NOTE: always the unconditional distribution; evidence is not consulted here.
    dist = self.distribution()

    if not hasattr(dist, "cdf"):
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support interval inference. "
            f"The underlying distribution {dist.__class__.__name__} has no cdf() method."
        )

    # Handle NaN bounds as -inf/+inf
    low_for_cdf = torch.where(
        torch.isnan(low_expanded),
        torch.full_like(low_expanded, float("-inf")),
        low_expanded,
    )
    high_for_cdf = torch.where(
        torch.isnan(high_expanded),
        torch.full_like(high_expanded, float("inf")),
        high_expanded,
    )

    # Compute CDF values
    try:
        cdf_high = dist.cdf(high_for_cdf)
        cdf_low = dist.cdf(low_for_cdf)
    except NotImplementedError as err:
        # Keep the original error as the cause so the traceback shows where the
        # distribution gave up.
        raise NotImplementedError(
            f"{self.__class__.__name__} does not support interval inference. "
            f"The underlying distribution {dist.__class__.__name__} does not implement cdf()."
        ) from err

    # P(low <= X <= high) = CDF(high) - CDF(low); clamp keeps the log finite.
    prob = torch.clamp(cdf_high - cdf_low, min=1e-40)
    return torch.log(prob)


def _prepare_mle_data(
    self,
    data: Tensor,
    weights: Tensor | None = None,
    nan_strategy: str | Callable | None = None,
) -> tuple[Tensor, Tensor]:
    """Prepare scope-filtered data and weights for MLE computation.

    Args:
        data: Input data tensor of shape (batch, num_features).
        weights: Optional sample weights. Defaults to all-ones of shape
            (batch, features, channels, repetitions) on the module's device.
        nan_strategy: Handle NaN ('ignore', callable, or None).

    Returns:
        The data and weights returned by ``apply_nan_strategy``.
    """
    # Step 1: Select scope-relevant features
    scoped_data = data[:, self.scope.query]

    if weights is None:
        weights = torch.ones(
            scoped_data.shape[0],
            self.out_shape.features,
            self.out_shape.channels,
            self.out_shape.repetitions,
            device=self.device,
        )

    # Step 2: Apply NaN strategy (drop/impute)
    scoped_data, normalized_weights = apply_nan_strategy(nan_strategy, scoped_data, self.device, weights)

    return scoped_data, normalized_weights
def maximum_likelihood_estimation(
    self,
    data: Tensor,
    weights: Optional[Tensor] = None,
    bias_correction: bool = True,
    nan_strategy: str | Callable | None = None,
) -> None:
    """Fit the leaf parameters by (weighted) maximum likelihood.

    Template method: data and weights are prepared by ``_prepare_mle_data`` and
    the distribution-specific update is delegated to ``_mle_update_statistics``.

    Args:
        data: Input data tensor.
        weights: Optional sample weights.
        bias_correction: Apply bias correction.
        nan_strategy: Handle NaN ('ignore', callable, or None).

    Raises:
        RuntimeError: If the leaf is conditional.
    """
    # Conditional leaves get their parameters from evidence; nothing to fit.
    if self.is_conditional:
        raise RuntimeError(f"MLE not supported for conditional leaf {self.__class__.__name__}.")

    prepared = self._prepare_mle_data(
        data=data,
        weights=weights,
        nan_strategy=nan_strategy,
    )
    scoped_data, scoped_weights = prepared

    self._mle_update_statistics(scoped_data, scoped_weights, bias_correction)
def _sample(
    self,
    data: Tensor,
    sampling_ctx: SamplingContext,
    cache: Cache,
) -> Tensor:
    """Fill NaN entries of ``data`` inside this leaf's scope with samples or modes.

    An entry is filled when it is NaN, lies in a scope column, and is enabled by
    the sampling-context mask. Whether values are drawn or the distribution mode
    is used is controlled by ``sampling_ctx.is_mpe``.

    Args:
        data: Evidence tensor indexed as (num_samples, num_columns). NaN marks
            entries that may be filled. The tensor is modified in place.
        sampling_ctx: Sampling context providing the mask, channel/repetition
            routing and the ``is_mpe`` / ``is_differentiable`` /
            ``return_leaf_params`` flags.
        cache: Cache object; not used by this implementation.

    Returns:
        The same ``data`` tensor, with the selected entries written.

    Raises:
        UnsupportedOperationError: If MPE is requested for a conditional leaf
            whose distribution has no ``mode`` attribute.
        ValueError: If the MPE values are not 4-dimensional after normalization,
            or if the number of produced rows differs from the number of
            active rows.
        RuntimeError: If a target position already holds a finite value.
    """
    num_columns = data.shape[1]
    sampling_ctx.validate_sampling_context(
        num_samples=data.shape[0],
        num_features=self.out_shape.features,
        num_channels=self.out_shape.channels,
        num_repetitions=self.out_shape.repetitions,
        allowed_feature_widths=(1, self.out_shape.features, num_columns),
    )

    scope_cols = self._resolve_scope_columns(num_features=num_columns)
    in_scope = set(scope_cols)  # O(1) membership for the column scan below
    out_of_scope = [idx for idx in range(num_columns) if idx not in in_scope]

    # samples_mask[i, j] is True where column j of row i must be filled by this leaf:
    # NaN in the data, inside the scope, and enabled by the sampling context.
    samples_mask = torch.isnan(data)
    samples_mask[:, out_of_scope] = False

    c_idx, ctx_mask = self._slice_sampling_context(
        sampling_ctx=sampling_ctx, num_features=num_columns, scope_cols=scope_cols
    )
    samples_mask[:, scope_cols] &= ctx_mask

    # Rows with at least one entry to fill.
    instance_mask = samples_mask.sum(1) > 0
    num_instances = int(instance_mask.sum().item())

    # Routing can legitimately send zero rows to a branch. In that case, there is nothing to
    # sample for this leaf and we should return without touching the data tensor.
    if num_instances == 0:
        return data

    evidence = None
    if self.is_conditional:
        evidence = data[instance_mask][:, self.scope.evidence]

    if sampling_ctx.is_mpe:
        if self.is_conditional:
            dist = self.conditional_distribution(
                evidence=evidence,
                with_differentiable_sampling=sampling_ctx.is_differentiable,
            )
            if not hasattr(dist, "mode"):
                raise UnsupportedOperationError(
                    f"MPE sampling requires a distribution with a 'mode' attribute, but "
                    f"{dist.__class__.__name__} does not provide one."
                )
            samples = dist.mode
            # Pad missing channel/repetition axes so the routing below sees 4 dimensions.
            if samples.dim() == 2:
                samples = rearrange(samples, "b f -> b f 1 1")
            elif samples.dim() == 3:
                samples = rearrange(samples, "b f ci -> b f ci 1")
        else:
            # Get mode of distribution as MPE
            dist = self.distribution(with_differentiable_sampling=sampling_ctx.is_differentiable)
            samples = rearrange(dist.mode, "f ci r -> 1 f ci r")

        if samples.ndim != 4:
            raise ValueError("The samples are not 4-dimensional. This should not happen.")

        if not self.is_conditional:
            # The unconditional mode has no batch axis: replicate it for every active row.
            # NOTE(review): the result is detached even when is_differentiable is set - confirm intended.
            samples = repeat(samples, "1 f ci r -> n f ci r", n=num_instances).detach()
    else:
        if self.is_conditional:
            dist = self.conditional_distribution(
                evidence=evidence,
                with_differentiable_sampling=sampling_ctx.is_differentiable,
            )
            # Distribution parameters already contain batch dim, therefore sample shape is (1,)
            if getattr(dist, "has_rsample", False):
                samples = dist.rsample((1,)).squeeze(0)
            else:
                samples = dist.sample((1,)).squeeze(0)
        else:
            # Draw one value per active row from the unconditional distribution.
            dist = self.distribution(with_differentiable_sampling=sampling_ctx.is_differentiable)
            if getattr(dist, "has_rsample", False):
                samples = dist.rsample((num_instances,))
            else:
                samples = dist.sample((num_instances,))

    # Select, per active row, the repetition chosen by the sampling context (last axis).
    r_idx = sampling_ctx.repetition_index[instance_mask]
    num_features = samples.shape[1]
    num_channels = samples.shape[2]
    r_idx = repeat_repetition_index(
        repetition_index=r_idx,
        pattern="b r -> b f c r",
        f=num_features,
        c=num_channels,
    )
    samples = index_tensor(
        samples, index=r_idx, dim=-1, is_differentiable=sampling_ctx.is_differentiable
    )

    expected_rows = sampling_ctx.channel_index[instance_mask].shape[0]
    if samples.shape[0] != expected_rows:
        raise ValueError(
            f"Sample shape mismatch: got {samples.shape[0]}, expected {expected_rows}"
        )

    # Select, per active row and feature, the channel chosen by the sampling context (axis 2).
    c_idx_active = c_idx[instance_mask]
    c_idx = c_idx_active
    if not sampling_ctx.is_differentiable:
        c_idx = rearrange(c_idx, "b f -> b f 1")
    samples = index_tensor(
        tensor=samples, index=c_idx, dim=2, is_differentiable=sampling_ctx.is_differentiable
    )

    if sampling_ctx.return_leaf_params:
        self._collect_leaf_param_record(
            sampling_ctx=sampling_ctx,
            instance_mask=instance_mask,
            scope_cols=scope_cols,
            samples_mask=samples_mask,
            channel_index_active=c_idx_active,
            evidence=evidence,
        )

    # Ensure, that no data is overwritten
    if data[samples_mask].isfinite().any():
        raise RuntimeError("Data already contains values at the specified mask. This should not happen.")

    # Vectorized in-place write: samples[:, k] goes to data[:, scope_cols[k]], but only at
    # positions where samples_mask is True.
    row_indices = instance_mask.nonzero(as_tuple=True)[0]  # (n_instances,)
    scope_idx = torch.tensor(scope_cols, dtype=torch.long, device=data.device)

    # Build all (row, col) index pairs, each of shape (n_instances, n_scope_features).
    num_scope_features = len(scope_idx)
    rows = repeat(row_indices, "n -> n s", s=num_scope_features)
    cols = repeat(scope_idx, "s -> n s", n=num_instances)

    # Mask restricted to active rows and scope columns.
    mask_subset = samples_mask[instance_mask][:, scope_cols]

    data[rows[mask_subset], cols[mask_subset]] = samples[mask_subset].to(data.dtype)
    return data


def _normalize_param_tensor_for_routing(self, param: Tensor, n_active: int) -> Tensor:
    """Normalize parameter tensor to shape (N_active, F, C, R, *tail).

    Accepts either an already batched tensor (leading dims ``n_active`` and
    ``features``) or an unbatched ``(F, C, R, *tail)`` tensor, which is repeated
    along a new leading axis. Channel and repetition axes of size 1 are expanded
    to the leaf's ``out_shape`` sizes.

    Args:
        param: Parameter tensor of rank >= 3.
        n_active: Number of active rows to route parameters for.

    Returns:
        Tensor whose first four axes are (n_active, features, channels, repetitions).

    Raises:
        ShapeError: If the rank is below 3, the leading dimensions match neither
            layout, or the channel/repetition sizes are neither 1 nor the
            expected size.
    """
    if param.dim() < 3:
        raise ShapeError(
            f"Leaf parameter tensors must have rank >= 3 (features, channels, repetitions), got rank {param.dim()}."
        )

    num_features = self.out_shape.features
    num_channels = self.out_shape.channels
    num_repetitions = self.out_shape.repetitions

    if param.dim() >= 4 and param.shape[0] == n_active and param.shape[1] == num_features:
        # Already batched per active row.
        routed = param
    elif param.shape[0] == num_features:
        # Unbatched parameters: share them across all active rows.
        if param.dim() == 3:
            routed = repeat(param, "f c r -> n f c r", n=n_active)
        else:
            routed = repeat(param, "f c r ... -> n f c r ...", n=n_active)
    else:
        raise ShapeError(
            "Leaf parameter tensor has incompatible leading dimensions: "
            f"got {tuple(param.shape)}, expected ({num_features}, C, R, ...) or "
            f"({n_active}, {num_features}, C, R, ...)."
        )

    if routed.dim() < 4:
        raise ShapeError(
            f"Leaf parameter tensor must include channels and repetitions after normalization, got {tuple(routed.shape)}."
        )
    if routed.shape[2] not in (1, num_channels):
        raise ShapeError(
            f"Leaf parameter tensor channel dimension mismatch: got {routed.shape[2]}, expected 1 or {num_channels}."
        )
    if routed.shape[3] not in (1, num_repetitions):
        raise ShapeError(
            f"Leaf parameter tensor repetition dimension mismatch: got {routed.shape[3]}, expected 1 or {num_repetitions}."
        )

    # Expand singleton channel / repetition axes to the full sizes.
    if routed.shape[2] == 1 and num_channels > 1:
        if routed.dim() == 4:
            routed = repeat(routed, "n f 1 r -> n f c r", c=num_channels)
        else:
            routed = repeat(routed, "n f 1 r ... -> n f c r ...", c=num_channels)
    if routed.shape[3] == 1 and num_repetitions > 1:
        if routed.dim() == 4:
            routed = repeat(routed, "n f c 1 -> n f c r", r=num_repetitions)
        else:
            routed = repeat(routed, "n f c 1 ... -> n f c r ...", r=num_repetitions)
    return routed


def _select_param_repetition(
    self,
    *,
    params: Tensor,
    sampling_ctx: SamplingContext,
    instance_mask: Tensor,
) -> Tensor:
    """Select routed repetition dimension, yielding shape (N, F, C, *tail).

    Args:
        params: Normalized parameters indexed as (N, F, C, R, *tail).
        sampling_ctx: Sampling context; ``repetition_index`` and
            ``is_differentiable`` are read.
        instance_mask: Boolean row mask selecting the active rows of
            ``sampling_ctx.repetition_index``.

    Returns:
        Parameters with the repetition axis removed.

    Raises:
        ShapeError: If the repetition routing tensor has an incompatible shape,
            or (differentiable mode) its rows do not sum to 1.
    """
    # Collapse all trailing axes into one so the selection works on a fixed rank.
    params_flat = rearrange(params, "n f c r ... -> n f c r (...)")
    tail_shape = params.shape[4:]

    if sampling_ctx.is_differentiable:
        # Differentiable routing carries per-row weights over repetitions.
        repetition_weights = sampling_ctx.repetition_index[instance_mask]
        if repetition_weights.dim() != 2 or repetition_weights.shape[1] != params.shape[3]:
            raise ShapeError(
                "Differentiable repetition routing has incompatible shape for leaf parameters: "
                f"got {tuple(repetition_weights.shape)}, expected ({params.shape[0]}, {params.shape[3]})."
            )
        # Only the row sums are verified here (each must be 1 within 1e-6).
        one_hot_sums = repetition_weights.sum(dim=1)
        if not torch.allclose(one_hot_sums, torch.ones_like(one_hot_sums), rtol=0.0, atol=1e-6):
            raise ShapeError(
                "Differentiable repetition routing for leaf parameters must be one-hot encoded per sample."
            )
        repetition_weights = repetition_weights.to(device=params.device, dtype=params.dtype)
        repetition_weights = rearrange(repetition_weights, "n r -> n 1 1 r 1")
        selected_flat = index_one_hot(params_flat, index=repetition_weights, dim=3)
        return selected_flat.reshape(params.shape[0], params.shape[1], params.shape[2], *tail_shape)

    # Non-differentiable routing carries integer repetition indices.
    repetition_index = sampling_ctx.repetition_index[instance_mask]
    if repetition_index.dim() == 2:
        if repetition_index.shape[1] != 1:
            raise ShapeError(
                "Non-differentiable repetition routing must be shape (batch,) or (batch, 1), "
                f"got {tuple(repetition_index.shape)}."
            )
        repetition_index = repetition_index[:, 0]
    repetition_index = repetition_index.to(device=params.device, dtype=torch.long)
    # Convert indices to one-hot weights so both modes share the same selection helper.
    repetition_weights = torch.nn.functional.one_hot(
        repetition_index,
        num_classes=params.shape[3],
    ).to(dtype=params.dtype)
    repetition_weights = rearrange(repetition_weights, "n r -> n 1 1 r 1")
    selected_flat = index_one_hot(params_flat, index=repetition_weights, dim=3)
    return selected_flat.reshape(params.shape[0], params.shape[1], params.shape[2], *tail_shape)


def _select_param_channel(
    self, *, params: Tensor, channel_index: Tensor, is_differentiable: bool
) -> Tensor:
    """Select routed channel dimension, yielding shape (N, F, *tail).

    Args:
        params: Parameters indexed as (N, F, C, *tail).
        channel_index: Channel routing for the active rows. Rank-3 weights
            (N, F or 1, C) in differentiable mode, rank-2 integer indices
            (N, F or 1) otherwise. A feature width of 1 is broadcast over all
            features.
        is_differentiable: Whether ``channel_index`` holds weights rather than
            integer indices.

    Returns:
        Parameters with the channel axis removed.

    Raises:
        ShapeError: If the routing tensor has the wrong rank or sizes, or
            (differentiable mode) its weights do not sum to 1 per
            (sample, feature).
    """
    # Collapse all trailing axes into one so the selection works on a fixed rank.
    params_flat = rearrange(params, "n f c ... -> n f c (...)")
    tail_shape = params.shape[3:]

    if is_differentiable:
        channel_weights = channel_index
        if channel_weights.dim() != 3 or channel_weights.shape[2] != params.shape[2]:
            raise ShapeError(
                "Differentiable channel routing has incompatible shape for leaf parameters: "
                f"got {tuple(channel_weights.shape)}, expected ({params.shape[0]}, {params.shape[1]}, {params.shape[2]})."
            )
        if channel_weights.shape[1] == 1 and params.shape[1] > 1:
            channel_weights = repeat(channel_weights, "n 1 c -> n f c", f=params.shape[1])
        elif channel_weights.shape[1] != params.shape[1]:
            raise ShapeError(
                "Differentiable channel routing feature width mismatch for leaf parameters: "
                f"got {channel_weights.shape[1]}, expected 1 or {params.shape[1]}."
            )
        # Only the sums over the channel axis are verified here (each must be 1 within 1e-6).
        one_hot_sums = channel_weights.sum(dim=2)
        if not torch.allclose(one_hot_sums, torch.ones_like(one_hot_sums), rtol=0.0, atol=1e-6):
            raise ShapeError(
                "Differentiable channel routing for leaf parameters must be one-hot encoded per (sample, feature)."
            )
        channel_weights = channel_weights.to(device=params.device, dtype=params.dtype)
        channel_weights = rearrange(channel_weights, "n f c -> n f c 1")
        selected_flat = index_one_hot(params_flat, index=channel_weights, dim=2)
        return selected_flat.reshape(params.shape[0], params.shape[1], *tail_shape)

    if channel_index.dim() != 2:
        raise ShapeError(
            "Non-differentiable channel routing for leaf parameters expects rank-2 indices, "
            f"got rank {channel_index.dim()}."
        )
    gather_index = channel_index.to(device=params.device, dtype=torch.long)
    if gather_index.shape[1] == 1 and params.shape[1] > 1:
        gather_index = repeat(gather_index, "n 1 -> n f", f=params.shape[1])
    elif gather_index.shape[1] != params.shape[1]:
        raise ShapeError(
            "Channel routing feature width mismatch for leaf parameters: "
            f"got {gather_index.shape[1]}, expected 1 or {params.shape[1]}."
        )
    # Convert indices to one-hot weights so both modes share the same selection helper.
    channel_weights = torch.nn.functional.one_hot(gather_index, num_classes=params.shape[2]).to(
        dtype=params.dtype,
        device=params.device,
    )
    channel_weights = rearrange(channel_weights, "n f c -> n f c 1")
    selected_flat = index_one_hot(params_flat, index=channel_weights, dim=2)
    return selected_flat.reshape(params.shape[0], params.shape[1], *tail_shape)


def _collect_leaf_param_record(
    self,
    *,
    sampling_ctx: SamplingContext,
    instance_mask: Tensor,
    scope_cols: list[int],
    samples_mask: Tensor,
    channel_index_active: Tensor,
    evidence: Tensor | None,
) -> None:
    """Collect routed leaf parameters into the sampling context.

    For every parameter tensor of the leaf, the repetition and channel chosen by
    the sampling context are selected for the active rows; the result is
    scattered into a zero-initialized tensor covering all rows and stored in a
    ``LeafParamRecord`` added to ``sampling_ctx``.

    Args:
        sampling_ctx: Sampling context that receives the record.
        instance_mask: Boolean mask over all rows marking the active ones.
        scope_cols: Data columns corresponding to this leaf's scope.
        samples_mask: Boolean mask over (rows, data columns) of entries this
            leaf fills.
        channel_index_active: Channel routing restricted to the active rows.
        evidence: Evidence for the active rows; required for conditional leaves.

    Raises:
        RuntimeError: If the leaf is conditional and ``evidence`` is None.
        UnsupportedOperationError: If a parameter value is not a Tensor.
    """
    n_active = int(instance_mask.sum().item())
    if n_active == 0:
        return

    if self.is_conditional:
        if evidence is None:
            raise RuntimeError("Conditional leaf parameter collection requires evidence.")
        params_dict = self.parameter_fn(evidence)
    else:
        params_dict = self.params()

    # Additionally expose ``log_scale`` when the leaf stores one but only reports ``scale``.
    if "scale" in params_dict and "log_scale" not in params_dict:
        log_scale_tensor = getattr(self, "log_scale", None)
        if isinstance(log_scale_tensor, Tensor):
            params_dict = dict(params_dict)  # copy so the leaf's own dict is not mutated
            params_dict["log_scale"] = log_scale_tensor

    active_mask = samples_mask[:, scope_cols].clone()
    routed_params: dict[str, Tensor] = {}
    for key, value in params_dict.items():
        if not isinstance(value, Tensor):
            raise UnsupportedOperationError(
                f"Leaf parameter '{key}' must be a Tensor when return_leaf_params=True, got {type(value)}."
            )
        params = self._normalize_param_tensor_for_routing(value, n_active=n_active)
        params = self._select_param_repetition(
            params=params,
            sampling_ctx=sampling_ctx,
            instance_mask=instance_mask,
        )
        params = self._select_param_channel(
            params=params,
            channel_index=channel_index_active,
            is_differentiable=sampling_ctx.is_differentiable,
        )
        # Scatter the active rows into a tensor covering all rows; inactive rows stay zero.
        full_shape = (int(instance_mask.shape[0]), len(scope_cols), *params.shape[2:])
        full_params = params.new_zeros(full_shape)
        full_params[instance_mask] = params
        routed_params[key] = full_params

    sampling_ctx.add_leaf_param_record(
        LeafParamRecord(
            leaf_id=id(self),
            leaf_type=self.__class__.__name__,
            scope_cols=tuple(scope_cols),
            active_mask=active_mask,
            params=routed_params,
        )
    )


def _resolve_scope_columns(self, num_features: int) -> list[int]:
    """Resolve the column indices in `data` that correspond to this leaf's scope.

    The sampling API can operate on either:
    - "Global" feature tensors where columns are indexed by RV id.
    - "Scoped" feature tensors where columns correspond exactly to this module's
      scope ordering.

    The global interpretation is tried first; the scoped one is used only when
    some RV id does not fit into ``num_features`` columns.

    Args:
        num_features: Number of feature columns in the provided data tensor.

    Returns:
        List of column indices into the provided data tensor that correspond to
        `self.scope.query`. Empty when the scope has no query variables.

    Raises:
        ShapeError: If neither interpretation fits ``num_features``.
    """
    query = list(self.scope.query)
    if len(query) == 0:
        return []

    # Global layout: every RV id is a valid column index.
    if all(0 <= rv < num_features for rv in query):
        return query

    # Scoped layout: one column per leaf feature, in scope order.
    if num_features == self.out_shape.features:
        return list(range(num_features))

    raise ShapeError(
        f"Cannot map scope {self.scope} to data with {num_features} features; "
        f"expected either all RV ids < {num_features} or a scoped tensor with {self.out_shape.features} features."
    )


def _slice_sampling_context(
    self, sampling_ctx: SamplingContext, num_features: int, scope_cols: list[int]
) -> tuple[Tensor, Tensor]:
    """Slice/expand sampling context tensors to align with this leaf's feature axis.

    Three context widths are accepted, checked in this order: exactly
    ``len(scope_cols)`` (returned as is), 1 (broadcast over the scope), and
    ``num_features`` (scope columns are selected).

    Args:
        sampling_ctx: Sampling context provided to sampling.
        num_features: Number of feature columns in the provided data tensor.
        scope_cols: Column indices (into data) that correspond to this leaf's scope.

    Returns:
        Tuple of (channel_index, mask) whose feature axis has length
        ``len(scope_cols)``. In differentiable mode the broadcast pattern used
        for ``channel_index`` keeps an additional trailing channel axis.

    Raises:
        ShapeError: If the context tensors cannot be aligned to the leaf scope.
    """
    ctx_features = sampling_ctx.mask.shape[1]
    scope_size = len(scope_cols)

    if ctx_features == scope_size:
        return sampling_ctx.channel_index, sampling_ctx.mask

    if ctx_features == 1:
        if sampling_ctx.is_differentiable:
            return (
                repeat(sampling_ctx.channel_index, "b 1 c -> b f c", f=scope_size),
                repeat(sampling_ctx.mask, "b 1 -> b f", f=scope_size),
            )
        return (
            repeat(sampling_ctx.channel_index, "b 1 -> b f", f=scope_size),
            repeat(sampling_ctx.mask, "b 1 -> b f", f=scope_size),
        )

    if ctx_features == num_features:
        return sampling_ctx.channel_index[:, scope_cols], sampling_ctx.mask[:, scope_cols]

    raise ShapeError(
        "SamplingContext feature dimension mismatch: "
        f"got {ctx_features}, expected {scope_size} or {num_features}."
    )
def marginalize(
    self,
    marg_rvs: list[int],
    prune: bool = True,
    cache: Cache | None = None,
) -> Optional["LeafModule"]:
    """Structurally marginalize specified variables.

    Builds a new leaf of the same class over the scope variables that are not
    in ``marg_rvs``, using detached copies of the corresponding parameters.

    Args:
        marg_rvs: Variable indices to marginalize.
        prune: Unused (for interface consistency).
        cache: Optional cache dictionary; not used by this implementation.

    Returns:
        Marginalized leaf or None if fully marginalized.

    Raises:
        RuntimeError: If the leaf is conditional.
    """
    if self.is_conditional:
        raise RuntimeError(
            f"Marginalization not supported for conditional leaf {self.__class__.__name__}."
        )

    # One pass over the scope: remember position and id of every variable that is kept.
    kept = [(idx, rv) for idx, rv in enumerate(self.scope.query) if rv not in marg_rvs]

    # Every variable was marginalized out: nothing remains of this leaf.
    if not kept:
        return None

    kept_idxs = [idx for idx, _ in kept]
    scope_marg = Scope([rv for _, rv in kept])

    # Parameters restricted to the kept positions, detached from the original leaf's graph.
    marg_params_dict = self.marginalized_params(kept_idxs)
    marg_params_dict = {name: value.detach() for name, value in marg_params_dict.items()}

    # Construct new object of the same class as this leaf.
    return self.__class__(
        scope=scope_marg,
        **marg_params_dict,
    )