"""Piecewise linear leaf distribution module.
This module provides a non-parametric density estimation approach that
approximates data distributions using piecewise linear functions constructed
from histograms. It uses K-means clustering to create multiple distributions
per leaf.
"""
from __future__ import annotations
import itertools
import logging
from typing import List, Optional
import torch
from einops import rearrange, repeat
from torch import Tensor, nn
from spflow.exceptions import OptionalDependencyError, ShapeError, UnsupportedOperationError
from spflow.meta.data.scope import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.cache import Cache
from spflow.utils.domain import DataType, Domain
from spflow.utils.histogram import get_bin_edges_torch
from spflow.utils.sampling_context import SamplingContext
logger = logging.getLogger(__name__)
def pairwise(iterable):
"""Iterate over consecutive pairs.
s -> (s0,s1), (s1,s2), (s2, s3), ...
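    Example:
        >>> list(pairwise("abc"))
        [('a', 'b'), ('b', 'c')]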
"""
a, b = itertools.tee(iterable)
next(b, None)
return zip(a, b)
def interp(x: Tensor, xp: Tensor, fp: Tensor, dim: int = -1, extrapolate: str = "constant") -> Tensor:
"""One-dimensional linear interpolation between monotonically increasing sample points.
Returns the one-dimensional piecewise linear interpolant to a function with
given discrete data points (xp, fp), evaluated at x.
Source: https://github.com/pytorch/pytorch/issues/50334#issuecomment-2304751532
Args:
x: The x-coordinates at which to evaluate the interpolated values.
xp: The x-coordinates of the data points, must be increasing.
fp: The y-coordinates of the data points, same shape as xp.
dim: Dimension across which to interpolate.
extrapolate: How to handle values outside the range of xp. Options:
- 'linear': Extrapolate linearly beyond range.
- 'constant': Use boundary value of fp for x outside xp.
Returns:
The interpolated values, same size as x.
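    Example (illustrative; assumes the default constant extrapolation and the
    non-negativity clamp applied below):
        >>> xp = torch.tensor([0.0, 1.0, 2.0])
        >>> fp = torch.tensor([0.0, 1.0, 0.0])
        >>> interp(torch.tensor([0.5, 1.5, 3.0]), xp, fp)
        tensor([0.5000, 0.5000, 0.0000])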
"""
# Move the interpolation dimension to the last axis
x = x.movedim(dim, -1)
xp = xp.movedim(dim, -1)
fp = fp.movedim(dim, -1)
m = torch.diff(fp) / torch.diff(xp) # slope
b = fp[..., :-1] - m * xp[..., :-1] # offset
# Ensure contiguous inputs for searchsorted
xp = xp.contiguous()
x = x.contiguous()
indices = torch.searchsorted(xp, x, right=False)
if extrapolate == "constant":
# Pad m and b to get constant values outside of xp range
m = torch.cat([torch.zeros_like(m)[..., :1], m, torch.zeros_like(m)[..., :1]], dim=-1)
b = torch.cat([fp[..., :1], b, fp[..., -1:]], dim=-1)
else: # extrapolate == 'linear'
indices = torch.clamp(indices - 1, 0, m.shape[-1] - 1)
values = m.gather(-1, indices) * x + b.gather(-1, indices)
    values = values.clamp(min=0.0)  # interpolants are densities; keep them non-negative
return values.movedim(-1, dim)
class PiecewiseLinearDist:
"""Custom distribution for piecewise linear density estimation.
Mimics the torch.distributions interface with log_prob, sample, and mode methods.
Attributes:
xs: Nested list of x-coordinates [R][L][F][C] where R=repetitions, L=leaves,
F=features, C=channels.
ys: Nested list of y-coordinates (densities) with same structure as xs.
domains: List of Domain objects, one per feature.
"""
def __init__(self, xs: List, ys: List, domains: List[Domain]):
"""Initialize the piecewise linear distribution.
Args:
xs: Nested list of x-coordinates for piecewise linear functions.
ys: Nested list of y-coordinates (densities) for piecewise linear functions.
domains: List of Domain objects describing each feature's domain.
"""
self.xs = xs
self.ys = ys
self.domains = domains
self.num_repetitions = len(xs)
self.num_leaves = len(xs[0])
self.num_features = len(xs[0][0])
self.num_channels = len(xs[0][0][0])
self._optimized_cache_ready = False
self._continuous_flat_indices: Tensor | None = None
self._flat_feature_indices: Tensor | None = None
self._flat_channel_indices: Tensor | None = None
self._flat_leaf_indices: Tensor | None = None
self._flat_repetition_indices: Tensor | None = None
self._xs_padded: Tensor | None = None
self._ys_padded: Tensor | None = None
self._lengths: Tensor | None = None
self._interp_slopes: Tensor | None = None
self._interp_offsets: Tensor | None = None
self._mode_values: Tensor | None = None
self._cdf_padded: Tensor | None = None
self._cdf_lengths: Tensor | None = None
@property
def _num_distributions(self) -> int:
return self.num_features * self.num_channels * self.num_repetitions * self.num_leaves
def _ensure_optimized_cache(self) -> None:
"""Pack nested parameter lists into padded tensors for batched kernels."""
if self._optimized_cache_ready:
return
device = self.xs[0][0][0][0].device
dtype = self.xs[0][0][0][0].dtype
flat_xs: list[Tensor] = []
flat_ys: list[Tensor] = []
feature_indices: list[int] = []
channel_indices: list[int] = []
leaf_indices: list[int] = []
repetition_indices: list[int] = []
continuous_flat_indices: list[int] = []
flat_idx = 0
for i_feature in range(self.num_features):
for i_channel in range(self.num_channels):
for i_repetition in range(self.num_repetitions):
for i_leaf in range(self.num_leaves):
flat_xs.append(self.xs[i_repetition][i_leaf][i_feature][i_channel])
flat_ys.append(self.ys[i_repetition][i_leaf][i_feature][i_channel])
feature_indices.append(i_feature)
channel_indices.append(i_channel)
leaf_indices.append(i_leaf)
repetition_indices.append(i_repetition)
if self.domains[i_feature].data_type == DataType.CONTINUOUS:
continuous_flat_indices.append(flat_idx)
flat_idx += 1
max_points = max(int(x.shape[0]) for x in flat_xs)
xs_padded = torch.full((self._num_distributions, max_points), float("inf"), device=device, dtype=dtype)
ys_padded = torch.zeros((self._num_distributions, max_points), device=device, dtype=dtype)
lengths = torch.empty((self._num_distributions,), device=device, dtype=torch.long)
for idx, (xs_i, ys_i) in enumerate(zip(flat_xs, flat_ys)):
length = int(xs_i.shape[0])
xs_padded[idx, :length] = xs_i
ys_padded[idx, :length] = ys_i
lengths[idx] = length
interval_mask = torch.arange(max_points - 1, device=device).unsqueeze(0) < (lengths - 1).unsqueeze(1)
delta_x = torch.diff(xs_padded, dim=1)
delta_y = torch.diff(ys_padded, dim=1)
safe_delta_x = torch.where(interval_mask, delta_x, torch.ones_like(delta_x))
slopes = torch.where(interval_mask, delta_y / safe_delta_x, torch.zeros_like(delta_y))
offsets = torch.where(
interval_mask,
ys_padded[:, :-1] - slopes * xs_padded[:, :-1],
torch.zeros_like(delta_y),
)
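        # Pad slopes with zero edge columns and offsets with the first/last
        # densities so that searchsorted indices 0 and `length` evaluate to
        # constant boundary values, mirroring interp() with
        # extrapolate="constant".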
zeros_edge = torch.zeros((self._num_distributions, 1), device=device, dtype=dtype)
last_indices = (lengths - 1).unsqueeze(1)
first_values = ys_padded[:, :1]
last_values = ys_padded.gather(dim=1, index=last_indices)
interp_slopes = torch.cat([zeros_edge, slopes, zeros_edge], dim=1)
interp_offsets = torch.cat([first_values, offsets, last_values], dim=1)
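        # The mode of a piecewise linear density is attained at a breakpoint,
        # so argmax over the densities (with padding masked to -inf) suffices.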
mode_scores = ys_padded.masked_fill(
torch.arange(max_points, device=device).unsqueeze(0) >= lengths.unsqueeze(1),
float("-inf"),
)
mode_indices = torch.argmax(mode_scores, dim=1, keepdim=True)
mode_values = xs_padded.gather(dim=1, index=mode_indices).squeeze(1)
continuous_index_tensor = torch.tensor(continuous_flat_indices, device=device, dtype=torch.long)
if continuous_index_tensor.numel() > 0:
cdf_padded = torch.full(
(continuous_index_tensor.numel(), max_points),
float("inf"),
device=device,
dtype=dtype,
)
cdf_lengths = lengths[continuous_index_tensor]
xs_cont = xs_padded[continuous_index_tensor]
ys_cont = ys_padded[continuous_index_tensor]
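            # Build a trapezoid-rule CDF per continuous distribution; rows have
            # different true lengths, so this loops row-wise instead of
            # vectorizing.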
for idx in range(int(continuous_index_tensor.numel())):
length = int(cdf_lengths[idx])
intervals = xs_cont[idx, 1:length] - xs_cont[idx, : length - 1]
trapezoids = 0.5 * intervals * (ys_cont[idx, : length - 1] + ys_cont[idx, 1:length])
cdf = torch.cat(
[
torch.zeros((1,), device=device, dtype=dtype),
torch.cumsum(trapezoids, dim=0),
]
)
cdf = cdf / (cdf[-1] + 1e-10)
cdf_padded[idx, :length] = cdf
else:
cdf_padded = torch.empty((0, max_points), device=device, dtype=dtype)
cdf_lengths = torch.empty((0,), device=device, dtype=torch.long)
self._flat_feature_indices = torch.tensor(feature_indices, device=device, dtype=torch.long)
self._flat_channel_indices = torch.tensor(channel_indices, device=device, dtype=torch.long)
self._flat_leaf_indices = torch.tensor(leaf_indices, device=device, dtype=torch.long)
self._flat_repetition_indices = torch.tensor(repetition_indices, device=device, dtype=torch.long)
self._continuous_flat_indices = continuous_index_tensor
self._xs_padded = xs_padded
self._ys_padded = ys_padded
self._lengths = lengths
self._interp_slopes = interp_slopes
self._interp_offsets = interp_offsets
self._mode_values = mode_values
self._cdf_padded = cdf_padded
self._cdf_lengths = cdf_lengths
self._optimized_cache_ready = True
def _reshape_flat(self, values: Tensor) -> Tensor:
"""Reshape flat distribution order [F, C, R, L] into [C, F, L, R]."""
return values.view(self.num_features, self.num_channels, self.num_repetitions, self.num_leaves).permute(
1, 0, 3, 2
)
def _reshape_flat_with_batch(self, values: Tensor) -> Tensor:
"""Reshape flat distribution order [N, F, C, R, L] into [N, C, F, L, R]."""
return values.view(
values.shape[0],
self.num_features,
self.num_channels,
self.num_repetitions,
self.num_leaves,
).permute(0, 2, 1, 4, 3)
def _compute_cdf(self, xs: Tensor, ys: Tensor) -> Tensor:
"""Compute the CDF for the given piecewise linear function.
Args:
xs: X-coordinates of the piecewise function.
ys: Y-coordinates (densities) of the piecewise function.
Returns:
CDF values at each x-coordinate.
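        Example (illustrative): for xs = [0, 1, 2] and ys = [0, 1, 0], the
        trapezoid areas are [0.5, 0.5], giving the normalized CDF [0, 0.5, 1].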
"""
# Compute the integral over each interval using the trapezoid rule
intervals = torch.diff(xs)
trapezoids = 0.5 * intervals * (ys[:-1] + ys[1:]) # Partial areas
# Cumulative sum to build the CDF
cdf = torch.cat([torch.zeros(1, device=xs.device), torch.cumsum(trapezoids, dim=0)])
# Normalize the CDF to ensure it goes from 0 to 1
cdf = cdf / (cdf[-1] + 1e-10)
return cdf
def sample(self, sample_shape: torch.Size | tuple[int, ...]) -> Tensor:
"""Sample from the piecewise linear distribution.
Args:
sample_shape: Shape of samples to generate.
Returns:
Samples tensor of shape (sample_shape[0], C, F, L, R).
"""
self._ensure_optimized_cache()
assert self._continuous_flat_indices is not None
assert self._cdf_lengths is not None
assert self._cdf_padded is not None
assert self._xs_padded is not None
        num_samples = sample_shape[0]
flat_samples = torch.empty(
(num_samples, self._num_distributions),
device=self.xs[0][0][0][0].device,
dtype=self.xs[0][0][0][0].dtype,
)
if self._continuous_flat_indices.numel() > 0:
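            # Continuous features: inverse-transform sampling. Draw uniforms,
            # locate the CDF segment containing each draw, then invert the
            # piecewise-linear CDF on that segment.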
            uniforms = torch.rand(
                (self._continuous_flat_indices.numel(), num_samples),
                device=flat_samples.device,
                dtype=flat_samples.dtype,
            )
indices = torch.searchsorted(self._cdf_padded, uniforms, right=True)
max_indices = (self._cdf_lengths - 1).unsqueeze(1)
indices = torch.minimum(indices, max_indices)
indices = torch.clamp(indices, min=1)
xs_cont = self._xs_padded[self._continuous_flat_indices]
cdf0 = self._cdf_padded.gather(dim=1, index=indices - 1)
cdf1 = self._cdf_padded.gather(dim=1, index=indices)
x0 = xs_cont.gather(dim=1, index=indices - 1)
x1 = xs_cont.gather(dim=1, index=indices)
slope = (x1 - x0) / (cdf1 - cdf0 + 1e-8)
cont_samples = x0 + slope * (uniforms - cdf0)
flat_samples[:, self._continuous_flat_indices] = cont_samples.transpose(0, 1)
continuous_lookup = set(self._continuous_flat_indices.tolist())
for flat_idx in range(self._num_distributions):
if flat_idx in continuous_lookup:
continue
i_feature = int(self._flat_feature_indices[flat_idx])
i_channel = int(self._flat_channel_indices[flat_idx])
i_leaf = int(self._flat_leaf_indices[flat_idx])
i_repetition = int(self._flat_repetition_indices[flat_idx])
ys_i = self.ys[i_repetition][i_leaf][i_feature][i_channel]
if self.domains[i_feature].data_type != DataType.DISCRETE:
raise ValueError(f"Unknown data type: {self.domains[i_feature].data_type}")
            # Sample bin indices from the discrete densities (tail breaks at
            # both ends excluded), then map them back to the support values.
            support = self.xs[i_repetition][i_leaf][i_feature][i_channel][1:-1]
            dist = torch.distributions.Categorical(probs=ys_i[1:-1])
            flat_samples[:, flat_idx] = support[dist.sample((num_samples,))]
return self._reshape_flat_with_batch(flat_samples)
@property
def mode(self) -> Tensor:
"""Compute the mode of the distribution.
Returns:
Modes tensor of shape (C, F, L, R).
"""
self._ensure_optimized_cache()
assert self._mode_values is not None
return self._reshape_flat(self._mode_values)
def log_prob(self, x: Tensor) -> Tensor:
"""Compute log probabilities for input data.
Args:
x: Input tensor of shape (N, C, F, 1, 1) or (N, C, F).
Returns:
Log probabilities of shape (N, C, F, L, R).
"""
if x.dim() == 5:
x = rearrange(x, "n c f 1 1 -> n c f")
self._ensure_optimized_cache()
assert self._flat_channel_indices is not None
assert self._flat_feature_indices is not None
assert self._xs_padded is not None
assert self._interp_slopes is not None
assert self._interp_offsets is not None
flat_queries = x[:, self._flat_channel_indices, self._flat_feature_indices]
query_matrix = flat_queries.transpose(0, 1).contiguous()
indices = torch.searchsorted(self._xs_padded, query_matrix, right=False)
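        # The padded slope/offset matrices carry zero-slope edge columns, so
        # out-of-range queries evaluate to the boundary densities (zero, given
        # the zero-density tail breaks added at initialization).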
probs_flat = self._interp_slopes.gather(dim=1, index=indices) * query_matrix + self._interp_offsets.gather(
dim=1, index=indices
)
probs_flat = probs_flat.clamp(min=0.0).transpose(0, 1)
logprobs = torch.log(probs_flat + 1e-10)
logprobs = torch.clamp(logprobs, min=-300.0)
return self._reshape_flat_with_batch(logprobs)
class PiecewiseLinear(LeafModule):
"""Piecewise linear leaf distribution module.
First constructs histograms from the data using K-means clustering,
then approximates the histograms with piecewise linear functions.
This leaf requires initialization with data via the `initialize()` method
before it can be used for inference or sampling.
Attributes:
alpha: Laplace smoothing parameter.
xs: Nested list of x-coordinates for piecewise linear functions.
ys: Nested list of y-coordinates (densities) for piecewise linear functions.
domains: List of Domain objects describing each feature.
is_initialized: Whether the distribution has been initialized with data.
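    Example:
        A minimal usage sketch. The ``Domain`` construction below is a
        hypothetical placeholder; consult ``spflow.utils.domain`` for the
        actual constructor::
            leaf = PiecewiseLinear(scope=[0, 1], out_channels=2)
            data = torch.randn(500, 2)
            domains = [Domain(data_type=DataType.CONTINUOUS)] * 2  # hypothetical
            leaf.initialize(data, domains)
            log_likes = leaf.log_likelihood(data)  # (N, C, F, L, R)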
"""
def __init__(
self,
scope: Scope | int | List[int],
out_channels: int = 1,
num_repetitions: int = 1,
alpha: float = 0.0,
):
"""Initialize PiecewiseLinear leaf module.
Args:
scope: Variable scope (Scope, int, or list[int]).
out_channels: Number of output channels (clusters via K-means).
num_repetitions: Number of repetitions.
alpha: Laplace smoothing parameter (default 0.0).
"""
super().__init__(
scope=scope,
out_channels=out_channels,
num_repetitions=num_repetitions,
)
if alpha < 0:
raise ValueError(f"alpha must be non-negative, got {alpha}")
self.alpha = alpha
# These will be set during initialization
self.xs: Optional[List] = None
self.ys: Optional[List] = None
self.domains: Optional[List[Domain]] = None
self.is_initialized = False
self._distribution_cache: Optional[PiecewiseLinearDist] = None
        # Register a dummy buffer so device detection works
self.register_buffer("_device_buffer", torch.zeros(1))
@property
def _torch_distribution_class(self):
"""PiecewiseLinear uses a custom distribution, not a torch.distributions class."""
return None
@property
def _supported_value(self) -> float:
"""Returns a value in the support of the distribution."""
return 0.0
def distribution(self, with_differentiable_sampling: bool = False) -> PiecewiseLinearDist:
"""Return the underlying PiecewiseLinearDist object.
Args:
with_differentiable_sampling: Whether to request a differentiable
sampling distribution.
        Returns:
            The cached PiecewiseLinearDist built from self.xs, self.ys, and
            self.domains.
        Raises:
            NotImplementedError: If differentiable sampling is requested.
            ValueError: If the distribution has not been initialized.
        """
if with_differentiable_sampling:
raise NotImplementedError(
"PiecewiseLinear does not support differentiable sampling. "
"Use distribution(with_differentiable_sampling=False)."
)
if not self.is_initialized:
raise ValueError(
"PiecewiseLinear leaf has not been initialized. " "Call initialize(data, domains) first."
)
if self._distribution_cache is None:
self._distribution_cache = PiecewiseLinearDist(self.xs, self.ys, self.domains) # type: ignore[arg-type]
return self._distribution_cache
@property
def mode(self) -> Tensor:
"""Return distribution mode.
Returns:
Mode of the distribution.
"""
return self.distribution().mode
def params(self) -> dict:
"""Returns the parameters of the distribution.
For PiecewiseLinear, returns xs and ys nested lists.
"""
return {"xs": self.xs, "ys": self.ys}
def _compute_parameter_estimates(self, data: Tensor, weights: Tensor, bias_correction: bool) -> dict:
"""Not implemented for PiecewiseLinear - use initialize() instead."""
raise NotImplementedError("PiecewiseLinear does not support MLE. Use initialize() instead.")
def initialize(self, data: Tensor, domains: List[Domain]) -> None:
"""Initialize the piecewise linear distribution with data.
Uses K-means clustering to create multiple distributions per leaf,
then constructs histograms and approximates them with piecewise
linear functions.
Args:
data: Training data tensor of shape (N, F) where N is batch size
and F is the number of features.
domains: List of Domain objects, one per feature.
Raises:
ValueError: If data shape doesn't match scope.
"""
try:
from fast_pytorch_kmeans import KMeans
except ImportError as e:
raise OptionalDependencyError(
"fast_pytorch_kmeans required for PiecewiseLinear. "
"Install with: pip install fast-pytorch-kmeans"
) from e
logger.info(f"Initializing PiecewiseLinear with data shape {data.shape}")
# Validate input
num_features = len(self.scope.query)
if data.shape[1] != num_features:
raise ValueError(f"Data has {data.shape[1]} features but scope has {num_features}")
if len(domains) != num_features:
raise ValueError(f"Got {len(domains)} domains but scope has {num_features} features")
self.domains = domains
device = data.device
# Parameters stored as nested lists [R][L][F][C]
xs = []
ys = []
num_leaves = self.out_shape.channels
for i_repetition in range(self.out_shape.repetitions):
xs_leaves = []
ys_leaves = []
# Cluster data into num_leaves clusters
if num_leaves > 1:
kmeans = KMeans(n_clusters=num_leaves, mode="euclidean", verbose=0, init_method="random")
kmeans.fit(data.float())
cluster_idxs = kmeans.max_sim(a=data.float(), b=kmeans.centroids)[1]
else:
cluster_idxs = torch.zeros(data.shape[0], dtype=torch.long, device=device)
for cluster_idx in range(num_leaves):
# Select data for this cluster
mask = cluster_idxs == cluster_idx
cluster_data = data[mask]
xs_features = []
ys_features = []
for i_feature in range(num_features):
xs_channels = []
ys_channels = []
# For PiecewiseLinear, we use a single "channel" per feature
# (the reference used num_channels but SPFlow uses out_channels for leaves)
data_subset = cluster_data[:, i_feature].float()
if self.domains[i_feature].data_type == DataType.DISCRETE:
# Edges are the discrete values
mids = torch.as_tensor(
self.domains[i_feature].values, device=device, dtype=torch.float32
)
# Add a break at the end
breaks = torch.cat([mids, mids[-1:].add(1)])
if data_subset.shape[0] == 0:
# If no data in cluster, use uniform
densities = torch.ones(len(mids), device=device) / len(mids)
else:
# Compute histogram densities
densities = torch.histogram(
data_subset.cpu(), bins=breaks.cpu(), density=True
).hist.to(device)
elif self.domains[i_feature].data_type == DataType.CONTINUOUS:
# Find histogram bins using automatic bin width
if data_subset.numel() > 0:
bins, _ = get_bin_edges_torch(data_subset)
else:
# Fallback for empty data
bins = torch.linspace(
self.domains[i_feature].min or 0,
self.domains[i_feature].max or 1,
11,
device=device,
)
# Construct histogram
if data_subset.numel() > 0:
densities = torch.histogram(
data_subset.cpu(), bins=bins.cpu(), density=True
).hist.to(device)
else:
densities = torch.ones(len(bins) - 1, device=device) / (len(bins) - 1)
breaks = bins
mids = ((breaks + torch.roll(breaks, shifts=-1, dims=0)) / 2)[:-1]
else:
raise ValueError(f"Unknown data type: {domains[i_feature].data_type}")
# Apply optional Laplace smoothing
if self.alpha > 0:
n_samples = data_subset.shape[0]
n_bins = len(breaks) - 1
counts = densities * n_samples
densities = (counts + self.alpha) / (n_samples + n_bins * self.alpha)
# Add tail breaks to start and end
if self.domains[i_feature].data_type == DataType.DISCRETE:
tail_width = 1
x = [b.item() for b in breaks[:-1]]
x = [x[0] - tail_width] + x + [x[-1] + tail_width]
elif self.domains[i_feature].data_type == DataType.CONTINUOUS:
EPS = 1e-8
x = (
[breaks[0].item() - EPS]
+ [b0.item() + (b1.item() - b0.item()) / 2 for (b0, b1) in pairwise(breaks)]
+ [breaks[-1].item() + EPS]
)
else:
raise ValueError(
f"Unknown data type in tail break construction: {self.domains[i_feature].data_type}"
)
# Add density 0 at start and end tail breaks
y = [0.0] + [d.item() for d in densities] + [0.0]
# Construct tensors
x = torch.tensor(x, device=device, dtype=torch.float32)
y = torch.tensor(y, device=device, dtype=torch.float32)
                    # Compute the area under the curve using the trapezoid rule
                    auc = torch.trapezoid(y=y, x=x)
                    # Normalize y so the density integrates to 1
if auc > 0:
y = y / auc
xs_channels.append(x)
ys_channels.append(y)
xs_features.append(xs_channels)
ys_features.append(ys_channels)
xs_leaves.append(xs_features)
ys_leaves.append(ys_features)
xs.append(xs_leaves)
ys.append(ys_leaves)
self.xs = xs
self.ys = ys
self.is_initialized = True
self._distribution_cache = None
logger.info("PiecewiseLinear initialization complete")
def reset(self) -> None:
"""Reset the distribution to uninitialized state."""
self.is_initialized = False
self.xs = None
self.ys = None
self.domains = None
self._distribution_cache = None
def log_likelihood(
self,
data: Tensor,
cache: Cache | None = None,
) -> Tensor:
"""Compute log-likelihoods for input data.
Args:
data: Input data tensor of shape (N, F).
cache: Optional cache dictionary.
Returns:
Log-likelihood tensor.
"""
if not self.is_initialized:
raise ValueError(
"PiecewiseLinear leaf has not been initialized. " "Call initialize(data, domains) first."
)
if data.dim() != 2:
raise ValueError(f"Data must be 2-dimensional (batch, num_features), got shape {data.shape}.")
# Get scope-relevant data
data_q = data[:, self.scope.query]
# Handle marginalization
marg_mask = torch.isnan(data_q)
has_marginalizations = marg_mask.any()
if has_marginalizations:
data_q = data_q.clone()
data_q[marg_mask] = self._supported_value
# Unsqueeze to add channel dimension
data_q = rearrange(data_q, "n f -> n 1 f")
# Compute log probabilities
dist = self.distribution()
log_prob = dist.log_prob(data_q)
# Marginalize entries
if has_marginalizations:
# Expand mask to match log_prob shape
marg_mask_expanded = rearrange(marg_mask, "n f -> n 1 f 1 1")
marg_mask_expanded = torch.broadcast_to(marg_mask_expanded, log_prob.shape)
log_prob[marg_mask_expanded] = 0.0
return log_prob
def _sample(
self,
data: Tensor,
sampling_ctx: SamplingContext,
cache: Cache,
) -> Tensor:
"""Sample from the piecewise linear distribution.
        Args:
            data: Evidence tensor of shape (N, F); in-scope NaN entries are
                filled with samples.
            sampling_ctx: Sampling context carrying routing indices and the
                MPE flag.
            cache: Cache dictionary.
        Returns:
            The data tensor with sampled values written into the previously
            missing (NaN) in-scope entries.
        """
if sampling_ctx.return_leaf_params:
raise UnsupportedOperationError(
"PiecewiseLinear.sample() does not support return_leaf_params=True yet."
)
if sampling_ctx.is_differentiable:
raise UnsupportedOperationError(
"PiecewiseLinear.sample() does not support differentiable routing yet."
)
if not self.is_initialized:
raise ValueError(
"PiecewiseLinear leaf has not been initialized. " "Call initialize(data, domains) first."
)
# Prepare data tensor
sampling_ctx.validate_sampling_context(
num_samples=data.shape[0],
num_features=self.out_shape.features,
num_channels=self.out_shape.channels,
num_repetitions=self.out_shape.repetitions,
allowed_feature_widths=(1, self.out_shape.features, data.shape[1]),
)
scope_cols = self._resolve_scope_columns(num_features=data.shape[1])
out_of_scope = list(filter(lambda x: x not in scope_cols, range(data.shape[1])))
marg_mask = torch.isnan(data)
marg_mask[:, out_of_scope] = False
# Mask that tells us which feature at which sample is relevant
samples_mask = marg_mask
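        # samples_mask aliases marg_mask: True exactly where the input is NaN
        # and in scope, i.e. where a value must be sampled.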
ctx_channel_index, ctx_mask = self._slice_sampling_context(
sampling_ctx=sampling_ctx,
num_features=data.shape[1],
scope_cols=scope_cols,
)
samples_mask[:, scope_cols] &= ctx_mask
# Count number of samples to draw
instance_mask = samples_mask.sum(1) > 0
n_samples = instance_mask.sum()
dist = self.distribution(with_differentiable_sampling=sampling_ctx.is_differentiable)
n_samples_int = int(n_samples.item())
if sampling_ctx.is_mpe:
samples = rearrange(dist.mode, "c f l r -> 1 c f l r")
samples = repeat(samples, "1 c f l r -> n c f l r", n=n_samples_int).detach()
else:
samples = dist.sample((n_samples_int,))
# Handle repetition index
if samples.ndim == 5:
repetition_index = sampling_ctx.repetition_index[instance_mask]
num_channels = samples.shape[1]
num_features = samples.shape[2]
num_leaves = samples.shape[3]
r_idxs = repeat(
rearrange(repetition_index, "n -> n 1 1 1 1"),
"n 1 1 1 1 -> n c f l 1",
c=num_channels,
f=num_features,
l=num_leaves,
)
samples = rearrange(torch.gather(samples, dim=-1, index=r_idxs), "n c f l 1 -> n c f l")
# Handle channel index - gather on leaves dimension (dim=3)
# samples shape after repetition handling: (N, C=1, F, L)
        if self.out_shape.channels == 1:
            # Only index 0 is valid with a single channel; clamp the sliced
            # indices rather than mutating the shared sampling context.
            ctx_channel_index = torch.zeros_like(ctx_channel_index)
# c_idxs needs shape (N, 1, F, 1) to gather on dim=3
c_idxs = ctx_channel_index[instance_mask]
num_features = samples.shape[2]
if c_idxs.dim() == 1:
c_idxs = c_idxs.unsqueeze(1)
if c_idxs.shape[1] == 1 and num_features > 1:
c_idxs = c_idxs.expand(-1, num_features)
elif c_idxs.shape[1] != num_features:
raise ShapeError(
"sampling_ctx.channel_index has incompatible feature width for PiecewiseLinear.sample: "
f"got {c_idxs.shape[1]}, expected 1 or {num_features}."
)
c_idxs = rearrange(c_idxs.to(torch.long), "n f -> n 1 f 1")
samples = samples.gather(dim=3, index=c_idxs).squeeze(3) # (N, 1, F)
# Squeeze channel dimension
samples = rearrange(samples, "n 1 f -> n f")
# Update data with samples
row_indices = instance_mask.nonzero(as_tuple=True)[0]
scope_idx = torch.tensor(scope_cols, dtype=torch.long, device=data.device)
num_scope_features = len(scope_idx)
rows = repeat(row_indices, "n -> n s", s=num_scope_features)
cols = repeat(scope_idx, "s -> n s", n=n_samples_int)
mask_subset = samples_mask[instance_mask][:, scope_cols]
data[rows[mask_subset], cols[mask_subset]] = samples[mask_subset].to(data.dtype)
return data