Source code for spflow.modules.leaves.piecewise_linear

"""Piecewise linear leaf distribution module.

This module provides a non-parametric density estimation approach that
approximates data distributions using piecewise linear functions constructed
from histograms. It uses K-means clustering to create multiple distributions
per leaf.
"""

from __future__ import annotations

import itertools
import logging
from typing import List, Optional

import torch
from einops import rearrange, repeat
from torch import Tensor, nn

from spflow.exceptions import OptionalDependencyError, ShapeError, UnsupportedOperationError
from spflow.meta.data.scope import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.cache import Cache
from spflow.utils.domain import DataType, Domain
from spflow.utils.histogram import get_bin_edges_torch
from spflow.utils.sampling_context import SamplingContext

logger = logging.getLogger(__name__)


def pairwise(iterable):
    """Yield overlapping pairs of consecutive items.

    s -> (s0, s1), (s1, s2), (s2, s3), ...

    Returns a lazy ``zip`` iterator; inputs with fewer than two items yield
    nothing.
    """
    current, ahead = itertools.tee(iterable)
    # Step the look-ahead copy forward by one item right away (no-op if empty).
    for _ in itertools.islice(ahead, 1):
        break
    return zip(current, ahead)


def interp(x: Tensor, xp: Tensor, fp: Tensor, dim: int = -1, extrapolate: str = "constant") -> Tensor:
    """One-dimensional linear interpolation between monotonically increasing sample points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points (xp, fp), evaluated at x. The result is clamped to
    be non-negative (presumably because callers interpolate densities).

    Source: https://github.com/pytorch/pytorch/issues/50334#issuecomment-2304751532

    Args:
        x: The x-coordinates at which to evaluate the interpolated values.
        xp: The x-coordinates of the data points, must be increasing.
        fp: The y-coordinates of the data points, same shape as xp.
        dim: Dimension across which to interpolate.
        extrapolate: How to handle values outside the range of xp. Options:
            - 'linear': Extrapolate linearly beyond range.
            - 'constant': Use boundary value of fp for x outside xp.

    Returns:
        The interpolated values (clamped to ``>= 0``), same size as x.

    Raises:
        ValueError: If ``extrapolate`` is neither ``"constant"`` nor ``"linear"``.
    """
    # Reject unknown modes instead of silently falling back to linear extrapolation.
    if extrapolate not in ("constant", "linear"):
        raise ValueError(f"extrapolate must be 'constant' or 'linear', got {extrapolate!r}")

    # Move the interpolation dimension to the last axis
    x = x.movedim(dim, -1)
    xp = xp.movedim(dim, -1)
    fp = fp.movedim(dim, -1)

    m = torch.diff(fp) / torch.diff(xp)  # slope of each segment
    b = fp[..., :-1] - m * xp[..., :-1]  # offset of each segment

    # Ensure contiguous inputs for searchsorted
    xp = xp.contiguous()
    x = x.contiguous()
    indices = torch.searchsorted(xp, x, right=False)

    if extrapolate == "constant":
        # Pad m and b so index 0 / index len(xp) map to flat segments holding
        # the boundary values fp[0] / fp[-1].
        m = torch.cat([torch.zeros_like(m)[..., :1], m, torch.zeros_like(m)[..., :1]], dim=-1)
        b = torch.cat([fp[..., :1], b, fp[..., -1:]], dim=-1)
    else:  # extrapolate == "linear"
        # Reuse the first / last real segment outside the xp range.
        indices = torch.clamp(indices - 1, 0, m.shape[-1] - 1)

    values = m.gather(-1, indices) * x + b.gather(-1, indices)

    values = values.clamp(min=0.0)

    return values.movedim(-1, dim)


class PiecewiseLinearDist:
    """Custom distribution for piecewise linear density estimation.

    Mimics the torch.distributions interface with log_prob, sample, and mode methods.
    The nested parameter lists are packed lazily into padded tensors on first use
    (see ``_ensure_optimized_cache``) so that evaluation and sampling run batched
    over all (feature, channel, repetition, leaf) combinations.

    Attributes:
        xs: Nested list of x-coordinates [R][L][F][C] where R=repetitions, L=leaves,
            F=features, C=channels.
        ys: Nested list of y-coordinates (densities) with same structure as xs.
        domains: List of Domain objects, one per feature.
    """

    def __init__(self, xs: List, ys: List, domains: List[Domain]):
        """Initialize the piecewise linear distribution.

        Args:
            xs: Nested list of x-coordinates for piecewise linear functions.
            ys: Nested list of y-coordinates (densities) for piecewise linear functions.
            domains: List of Domain objects describing each feature's domain.
        """
        self.xs = xs
        self.ys = ys
        self.domains = domains

        self.num_repetitions = len(xs)
        self.num_leaves = len(xs[0])
        self.num_features = len(xs[0][0])
        self.num_channels = len(xs[0][0][0])

        # Packed tensors, populated lazily by _ensure_optimized_cache().
        self._optimized_cache_ready = False
        self._continuous_flat_indices: Tensor | None = None
        self._flat_feature_indices: Tensor | None = None
        self._flat_channel_indices: Tensor | None = None
        self._flat_leaf_indices: Tensor | None = None
        self._flat_repetition_indices: Tensor | None = None
        self._xs_padded: Tensor | None = None
        self._ys_padded: Tensor | None = None
        self._lengths: Tensor | None = None
        self._interp_slopes: Tensor | None = None
        self._interp_offsets: Tensor | None = None
        self._mode_values: Tensor | None = None
        self._cdf_padded: Tensor | None = None
        self._cdf_lengths: Tensor | None = None

    @property
    def _num_distributions(self) -> int:
        """Total number of univariate piecewise linear functions (F * C * R * L)."""
        return self.num_features * self.num_channels * self.num_repetitions * self.num_leaves

    def _ensure_optimized_cache(self) -> None:
        """Pack nested parameter lists into padded tensors for batched kernels.

        Distributions are flattened in [F, C, R, L] order. Rows of ``xs_padded``
        are right-padded with +inf (and ``ys_padded`` with 0) up to the longest
        knot list. Runs only once per instance.
        """
        if self._optimized_cache_ready:
            return

        device = self.xs[0][0][0][0].device
        dtype = self.xs[0][0][0][0].dtype

        flat_xs: list[Tensor] = []
        flat_ys: list[Tensor] = []
        feature_indices: list[int] = []
        channel_indices: list[int] = []
        leaf_indices: list[int] = []
        repetition_indices: list[int] = []
        continuous_flat_indices: list[int] = []

        flat_idx = 0
        for i_feature in range(self.num_features):
            for i_channel in range(self.num_channels):
                for i_repetition in range(self.num_repetitions):
                    for i_leaf in range(self.num_leaves):
                        flat_xs.append(self.xs[i_repetition][i_leaf][i_feature][i_channel])
                        flat_ys.append(self.ys[i_repetition][i_leaf][i_feature][i_channel])
                        feature_indices.append(i_feature)
                        channel_indices.append(i_channel)
                        leaf_indices.append(i_leaf)
                        repetition_indices.append(i_repetition)
                        if self.domains[i_feature].data_type == DataType.CONTINUOUS:
                            continuous_flat_indices.append(flat_idx)
                        flat_idx += 1

        max_points = max(int(x.shape[0]) for x in flat_xs)
        # +inf padding keeps each row sorted, so searchsorted never selects a padded knot
        # for a finite query below it.
        xs_padded = torch.full((self._num_distributions, max_points), float("inf"), device=device, dtype=dtype)
        ys_padded = torch.zeros((self._num_distributions, max_points), device=device, dtype=dtype)
        lengths = torch.empty((self._num_distributions,), device=device, dtype=torch.long)

        for idx, (xs_i, ys_i) in enumerate(zip(flat_xs, flat_ys)):
            length = int(xs_i.shape[0])
            xs_padded[idx, :length] = xs_i
            ys_padded[idx, :length] = ys_i
            lengths[idx] = length

        # Per-segment slope/offset; segments touching padding are forced to 0 to avoid inf/nan.
        interval_mask = torch.arange(max_points - 1, device=device).unsqueeze(0) < (lengths - 1).unsqueeze(1)
        delta_x = torch.diff(xs_padded, dim=1)
        delta_y = torch.diff(ys_padded, dim=1)
        safe_delta_x = torch.where(interval_mask, delta_x, torch.ones_like(delta_x))
        slopes = torch.where(interval_mask, delta_y / safe_delta_x, torch.zeros_like(delta_y))
        offsets = torch.where(
            interval_mask,
            ys_padded[:, :-1] - slopes * xs_padded[:, :-1],
            torch.zeros_like(delta_y),
        )
        # Pad with flat segments at both ends (constant extrapolation, as in interp()).
        zeros_edge = torch.zeros((self._num_distributions, 1), device=device, dtype=dtype)
        last_indices = (lengths - 1).unsqueeze(1)
        first_values = ys_padded[:, :1]
        last_values = ys_padded.gather(dim=1, index=last_indices)
        interp_slopes = torch.cat([zeros_edge, slopes, zeros_edge], dim=1)
        interp_offsets = torch.cat([first_values, offsets, last_values], dim=1)

        # Mode = knot with the largest density, ignoring padded positions.
        mode_scores = ys_padded.masked_fill(
            torch.arange(max_points, device=device).unsqueeze(0) >= lengths.unsqueeze(1),
            float("-inf"),
        )
        mode_indices = torch.argmax(mode_scores, dim=1, keepdim=True)
        mode_values = xs_padded.gather(dim=1, index=mode_indices).squeeze(1)

        # CDFs are only needed for continuous features (inverse-CDF sampling).
        continuous_index_tensor = torch.tensor(continuous_flat_indices, device=device, dtype=torch.long)
        if continuous_index_tensor.numel() > 0:
            cdf_padded = torch.full(
                (continuous_index_tensor.numel(), max_points),
                float("inf"),
                device=device,
                dtype=dtype,
            )
            cdf_lengths = lengths[continuous_index_tensor]
            xs_cont = xs_padded[continuous_index_tensor]
            ys_cont = ys_padded[continuous_index_tensor]
            for idx in range(int(continuous_index_tensor.numel())):
                length = int(cdf_lengths[idx])
                cdf_padded[idx, :length] = self._compute_cdf(xs_cont[idx, :length], ys_cont[idx, :length])
        else:
            cdf_padded = torch.empty((0, max_points), device=device, dtype=dtype)
            cdf_lengths = torch.empty((0,), device=device, dtype=torch.long)

        self._flat_feature_indices = torch.tensor(feature_indices, device=device, dtype=torch.long)
        self._flat_channel_indices = torch.tensor(channel_indices, device=device, dtype=torch.long)
        self._flat_leaf_indices = torch.tensor(leaf_indices, device=device, dtype=torch.long)
        self._flat_repetition_indices = torch.tensor(repetition_indices, device=device, dtype=torch.long)
        self._continuous_flat_indices = continuous_index_tensor
        self._xs_padded = xs_padded
        self._ys_padded = ys_padded
        self._lengths = lengths
        self._interp_slopes = interp_slopes
        self._interp_offsets = interp_offsets
        self._mode_values = mode_values
        self._cdf_padded = cdf_padded
        self._cdf_lengths = cdf_lengths
        self._optimized_cache_ready = True

    def _reshape_flat(self, values: Tensor) -> Tensor:
        """Reshape flat distribution order [F, C, R, L] into [C, F, L, R]."""
        return values.view(self.num_features, self.num_channels, self.num_repetitions, self.num_leaves).permute(
            1, 0, 3, 2
        )

    def _reshape_flat_with_batch(self, values: Tensor) -> Tensor:
        """Reshape flat distribution order [N, F, C, R, L] into [N, C, F, L, R]."""
        return values.view(
            values.shape[0],
            self.num_features,
            self.num_channels,
            self.num_repetitions,
            self.num_leaves,
        ).permute(0, 2, 1, 4, 3)

    def _compute_cdf(self, xs: Tensor, ys: Tensor) -> Tensor:
        """Compute the CDF for the given piecewise linear function.

        Args:
            xs: X-coordinates of the piecewise function (1-D, increasing).
            ys: Y-coordinates (densities) of the piecewise function.

        Returns:
            CDF values at each x-coordinate, normalized so the last value is ~1.
        """
        # Compute the integral over each interval using the trapezoid rule
        intervals = torch.diff(xs)
        trapezoids = 0.5 * intervals * (ys[:-1] + ys[1:])  # Partial areas

        # Cumulative sum to build the CDF; the leading zero matches xs' dtype/device.
        cdf = torch.cat([torch.zeros(1, device=xs.device, dtype=xs.dtype), torch.cumsum(trapezoids, dim=0)])

        # Normalize the CDF to ensure it goes from 0 to 1
        cdf = cdf / (cdf[-1] + 1e-10)

        return cdf

    def sample(self, sample_shape: torch.Size | tuple[int, ...]) -> Tensor:
        """Sample from the piecewise linear distribution.

        Continuous features use inverse-CDF sampling on the piecewise linear CDF.
        Discrete features draw a category from the interior knot densities and
        return the corresponding support value (the knot x-coordinate).

        Args:
            sample_shape: Shape of samples to generate; only ``sample_shape[0]``
                (the number of samples) is used for the output layout.

        Returns:
            Samples tensor of shape (sample_shape[0], C, F, L, R).

        Raises:
            ValueError: If a non-continuous feature is not discrete.
        """
        self._ensure_optimized_cache()
        assert self._continuous_flat_indices is not None
        assert self._cdf_lengths is not None
        assert self._cdf_padded is not None
        assert self._xs_padded is not None
        assert self._flat_feature_indices is not None
        assert self._flat_channel_indices is not None
        assert self._flat_leaf_indices is not None
        assert self._flat_repetition_indices is not None

        num_samples = sample_shape[0]
        flat_samples = torch.empty(
            (num_samples, self._num_distributions),
            device=self.xs[0][0][0][0].device,
            dtype=self.xs[0][0][0][0].dtype,
        )

        if self._continuous_flat_indices.numel() > 0:
            uniforms = torch.empty(
                (self._continuous_flat_indices.numel(), num_samples),
                device=flat_samples.device,
                dtype=flat_samples.dtype,
            )
            # One rand() call per distribution (kept row-wise to preserve the RNG draw order).
            for idx in range(int(self._continuous_flat_indices.numel())):
                uniforms[idx] = torch.rand(num_samples, device=flat_samples.device, dtype=flat_samples.dtype)
            # Locate the CDF segment containing each uniform and invert it linearly.
            indices = torch.searchsorted(self._cdf_padded, uniforms, right=True)
            max_indices = (self._cdf_lengths - 1).unsqueeze(1)
            indices = torch.minimum(indices, max_indices)
            indices = torch.clamp(indices, min=1)

            xs_cont = self._xs_padded[self._continuous_flat_indices]
            cdf0 = self._cdf_padded.gather(dim=1, index=indices - 1)
            cdf1 = self._cdf_padded.gather(dim=1, index=indices)
            x0 = xs_cont.gather(dim=1, index=indices - 1)
            x1 = xs_cont.gather(dim=1, index=indices)
            slope = (x1 - x0) / (cdf1 - cdf0 + 1e-8)
            cont_samples = x0 + slope * (uniforms - cdf0)
            flat_samples[:, self._continuous_flat_indices] = cont_samples.transpose(0, 1)

        continuous_lookup = set(self._continuous_flat_indices.tolist())
        for flat_idx in range(self._num_distributions):
            if flat_idx in continuous_lookup:
                continue
            i_feature = int(self._flat_feature_indices[flat_idx])
            i_channel = int(self._flat_channel_indices[flat_idx])
            i_leaf = int(self._flat_leaf_indices[flat_idx])
            i_repetition = int(self._flat_repetition_indices[flat_idx])
            if self.domains[i_feature].data_type != DataType.DISCRETE:
                raise ValueError(f"Unknown data type: {self.domains[i_feature].data_type}")
            xs_i = self.xs[i_repetition][i_leaf][i_feature][i_channel]
            ys_i = self.ys[i_repetition][i_leaf][i_feature][i_channel]
            # Drop the zero-density tail knots; the remaining xs are the discrete support values.
            support = xs_i[1:-1]
            dist = torch.distributions.Categorical(probs=ys_i[1:-1])
            categories = dist.sample(sample_shape)
            # Map category indices to actual support values (consistent with mode/log_prob).
            flat_samples[:, flat_idx] = support[categories]

        return self._reshape_flat_with_batch(flat_samples)

    @property
    def mode(self) -> Tensor:
        """Compute the mode of the distribution.

        Returns:
            Modes tensor of shape (C, F, L, R): the knot x-coordinate with the
            highest density for each distribution.
        """
        self._ensure_optimized_cache()
        assert self._mode_values is not None
        return self._reshape_flat(self._mode_values)

    def log_prob(self, x: Tensor) -> Tensor:
        """Compute log probabilities for input data.

        Args:
            x: Input tensor of shape (N, C, F, 1, 1) or (N, C, F).

        Returns:
            Log probabilities of shape (N, C, F, L, R). Densities are floored by
            ``1e-10`` before the log and the result is clamped at ``-300``.
        """
        if x.dim() == 5:
            x = rearrange(x, "n c f 1 1 -> n c f")

        self._ensure_optimized_cache()
        assert self._flat_channel_indices is not None
        assert self._flat_feature_indices is not None
        assert self._xs_padded is not None
        assert self._interp_slopes is not None
        assert self._interp_offsets is not None

        # Gather, for each flat distribution, the query column it should evaluate.
        flat_queries = x[:, self._flat_channel_indices, self._flat_feature_indices]
        query_matrix = flat_queries.transpose(0, 1).contiguous()
        indices = torch.searchsorted(self._xs_padded, query_matrix, right=False)
        probs_flat = self._interp_slopes.gather(dim=1, index=indices) * query_matrix + self._interp_offsets.gather(
            dim=1, index=indices
        )
        probs_flat = probs_flat.clamp(min=0.0).transpose(0, 1)
        logprobs = torch.log(probs_flat + 1e-10)
        logprobs = torch.clamp(logprobs, min=-300.0)
        return self._reshape_flat_with_batch(logprobs)


class PiecewiseLinear(LeafModule):
    """Piecewise linear leaf distribution module.

    First constructs histograms from the data using K-means clustering, then
    approximates the histograms with piecewise linear functions.

    This leaf requires initialization with data via the `initialize()` method
    before it can be used for inference or sampling.

    Attributes:
        alpha: Laplace smoothing parameter.
        xs: Nested list of x-coordinates for piecewise linear functions.
        ys: Nested list of y-coordinates (densities) for piecewise linear functions.
        domains: List of Domain objects describing each feature.
        is_initialized: Whether the distribution has been initialized with data.
    """

    def __init__(
        self,
        scope: Scope | int | List[int],
        out_channels: int = 1,
        num_repetitions: int = 1,
        alpha: float = 0.0,
    ):
        """Initialize PiecewiseLinear leaf module.

        Args:
            scope: Variable scope (Scope, int, or list[int]).
            out_channels: Number of output channels (clusters via K-means).
            num_repetitions: Number of repetitions.
            alpha: Laplace smoothing parameter (default 0.0).

        Raises:
            ValueError: If ``alpha`` is negative.
        """
        super().__init__(
            scope=scope,
            out_channels=out_channels,
            num_repetitions=num_repetitions,
        )
        if alpha < 0:
            raise ValueError(f"alpha must be non-negative, got {alpha}")
        self.alpha = alpha

        # These will be set during initialization
        self.xs: Optional[List] = None
        self.ys: Optional[List] = None
        self.domains: Optional[List[Domain]] = None
        self.is_initialized = False
        self._distribution_cache: Optional[PiecewiseLinearDist] = None

        # Register a dummy buffer so device detection works
        self.register_buffer("_device_buffer", torch.zeros(1))

    @property
    def _torch_distribution_class(self):
        """PiecewiseLinear uses a custom distribution, not a torch.distributions class."""
        return None

    @property
    def _supported_value(self) -> float:
        """Returns a value in the support of the distribution."""
        return 0.0

    def distribution(self, with_differentiable_sampling: bool = False) -> PiecewiseLinearDist:
        """Return the underlying PiecewiseLinearDist object.

        The object is built lazily and cached until the next ``initialize()`` or
        ``reset()``.

        Args:
            with_differentiable_sampling: Whether to request a differentiable sampling distribution.

        Raises:
            NotImplementedError: If ``with_differentiable_sampling`` is True.
            ValueError: If the distribution has not been initialized.
        """
        if with_differentiable_sampling:
            raise NotImplementedError(
                "PiecewiseLinear does not support differentiable sampling. "
                "Use distribution(with_differentiable_sampling=False)."
            )
        if not self.is_initialized:
            raise ValueError(
                "PiecewiseLinear leaf has not been initialized. " "Call initialize(data, domains) first."
            )
        if self._distribution_cache is None:
            self._distribution_cache = PiecewiseLinearDist(self.xs, self.ys, self.domains)  # type: ignore[arg-type]
        return self._distribution_cache

    @property
    def mode(self) -> Tensor:
        """Return distribution mode.

        Returns:
            Mode of the distribution.
        """
        return self.distribution().mode

    def params(self) -> dict:
        """Returns the parameters of the distribution.

        For PiecewiseLinear, returns xs and ys nested lists.
        """
        return {"xs": self.xs, "ys": self.ys}

    def _compute_parameter_estimates(self, data: Tensor, weights: Tensor, bias_correction: bool) -> dict:
        """Not implemented for PiecewiseLinear - use initialize() instead."""
        raise NotImplementedError("PiecewiseLinear does not support MLE. Use initialize() instead.")

    def initialize(self, data: Tensor, domains: List[Domain]) -> None:
        """Initialize the piecewise linear distribution with data.

        Uses K-means clustering to create multiple distributions per leaf, then
        constructs histograms and approximates them with piecewise linear functions.

        Args:
            data: Training data tensor of shape (N, F) where N is batch size and
                F is the number of features.
            domains: List of Domain objects, one per feature.

        Raises:
            OptionalDependencyError: If ``fast_pytorch_kmeans`` is not installed.
            ValueError: If data shape or number of domains doesn't match scope,
                or a domain has an unknown data type.
        """
        try:
            from fast_pytorch_kmeans import KMeans
        except ImportError as e:
            raise OptionalDependencyError(
                "fast_pytorch_kmeans required for PiecewiseLinear. " "Install with: pip install fast-pytorch-kmeans"
            ) from e

        logger.info(f"Initializing PiecewiseLinear with data shape {data.shape}")

        # Validate input
        num_features = len(self.scope.query)
        if data.shape[1] != num_features:
            raise ValueError(f"Data has {data.shape[1]} features but scope has {num_features}")
        if len(domains) != num_features:
            raise ValueError(f"Got {len(domains)} domains but scope has {num_features} features")

        self.domains = domains
        device = data.device

        # Parameters stored as nested lists [R][L][F][C]
        xs = []
        ys = []
        num_leaves = self.out_shape.channels
        for _ in range(self.out_shape.repetitions):
            xs_leaves = []
            ys_leaves = []

            # Cluster data into num_leaves clusters (K-means is re-fit per repetition)
            if num_leaves > 1:
                kmeans = KMeans(n_clusters=num_leaves, mode="euclidean", verbose=0, init_method="random")
                kmeans.fit(data.float())
                cluster_idxs = kmeans.max_sim(a=data.float(), b=kmeans.centroids)[1]
            else:
                cluster_idxs = torch.zeros(data.shape[0], dtype=torch.long, device=device)

            for cluster_idx in range(num_leaves):
                # Select data for this cluster
                cluster_data = data[cluster_idxs == cluster_idx]
                xs_features = []
                ys_features = []
                for i_feature in range(num_features):
                    # PiecewiseLinear uses a single "channel" per feature
                    # (the reference used num_channels but SPFlow uses out_channels for leaves)
                    x, y = self._fit_feature(cluster_data[:, i_feature].float(), self.domains[i_feature], device)
                    xs_features.append([x])
                    ys_features.append([y])
                xs_leaves.append(xs_features)
                ys_leaves.append(ys_features)
            xs.append(xs_leaves)
            ys.append(ys_leaves)

        self.xs = xs
        self.ys = ys
        self.is_initialized = True
        self._distribution_cache = None
        logger.info("PiecewiseLinear initialization complete")

    def _fit_feature(self, data_subset: Tensor, domain: Domain, device: torch.device) -> tuple[Tensor, Tensor]:
        """Fit one piecewise linear density (knot xs, densities ys) to 1-D data.

        Args:
            data_subset: 1-D float tensor of values for one feature in one cluster (may be empty).
            domain: Domain of the feature.
            device: Device for the returned tensors.

        Returns:
            Tuple ``(x, y)`` of float32 tensors; ``y`` is zero at both tail knots and
            normalized to unit trapezoid area when that area is positive.

        Raises:
            ValueError: If the domain's data type is neither discrete nor continuous.
        """
        if domain.data_type == DataType.DISCRETE:
            # Edges are the discrete values, plus a closing break after the last one
            mids = torch.as_tensor(domain.values, device=device, dtype=torch.float32)
            breaks = torch.cat([mids, mids[-1:].add(1)])
            if data_subset.shape[0] == 0:
                # If no data in cluster, use uniform
                densities = torch.ones(len(mids), device=device) / len(mids)
            else:
                densities = torch.histogram(data_subset.cpu(), bins=breaks.cpu(), density=True).hist.to(device)
        elif domain.data_type == DataType.CONTINUOUS:
            if data_subset.numel() > 0:
                # Find histogram bins using automatic bin width
                bins, _ = get_bin_edges_torch(data_subset)
                densities = torch.histogram(data_subset.cpu(), bins=bins.cpu(), density=True).hist.to(device)
            else:
                # Fallback for empty data; only substitute defaults for missing bounds
                # (a legitimate bound of 0 must not be replaced).
                low = domain.min if domain.min is not None else 0
                high = domain.max if domain.max is not None else 1
                bins = torch.linspace(low, high, 11, device=device)
                densities = torch.ones(len(bins) - 1, device=device) / (len(bins) - 1)
            breaks = bins
        else:
            raise ValueError(f"Unknown data type: {domain.data_type}")

        # Apply optional Laplace smoothing on raw counts. A density histogram stores
        # count / (n * width), so multiply the width back in before smoothing and
        # divide it out afterwards (a no-op for unit-width discrete bins).
        if self.alpha > 0:
            n_samples = data_subset.shape[0]
            bin_widths = torch.diff(breaks)
            n_bins = bin_widths.numel()
            counts = densities * n_samples * bin_widths
            probs = (counts + self.alpha) / (n_samples + n_bins * self.alpha)
            densities = probs / bin_widths

        # Add tail breaks to start and end
        if domain.data_type == DataType.DISCRETE:
            tail_width = 1
            x = [b.item() for b in breaks[:-1]]
            x = [x[0] - tail_width] + x + [x[-1] + tail_width]
        else:  # continuous; other data types were rejected above
            EPS = 1e-8
            x = (
                [breaks[0].item() - EPS]
                + [b0.item() + (b1.item() - b0.item()) / 2 for (b0, b1) in pairwise(breaks)]
                + [breaks[-1].item() + EPS]
            )

        # Add density 0 at start and end tail breaks
        y = [0.0] + [d.item() for d in densities] + [0.0]

        x_t = torch.tensor(x, device=device, dtype=torch.float32)
        y_t = torch.tensor(y, device=device, dtype=torch.float32)

        # Normalize y to unit area using the trapezoid rule
        auc = torch.trapezoid(y=y_t, x=x_t)
        if auc > 0:
            y_t = y_t / auc
        return x_t, y_t

    def reset(self) -> None:
        """Reset the distribution to uninitialized state."""
        self.is_initialized = False
        self.xs = None
        self.ys = None
        self.domains = None
        self._distribution_cache = None

    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log-likelihoods for input data.

        NaN entries in scope columns are treated as marginalized and get log-likelihood 0.

        Args:
            data: Input data tensor of shape (N, F).
            cache: Optional cache dictionary (unused).

        Returns:
            Log-likelihood tensor of shape (N, 1, F, L, R).

        Raises:
            ValueError: If the leaf is not initialized or ``data`` is not 2-D.
        """
        if not self.is_initialized:
            raise ValueError(
                "PiecewiseLinear leaf has not been initialized. " "Call initialize(data, domains) first."
            )
        if data.dim() != 2:
            raise ValueError(f"Data must be 2-dimensional (batch, num_features), got shape {data.shape}.")

        # Get scope-relevant data
        data_q = data[:, self.scope.query]

        # Handle marginalization: replace NaNs with an in-support placeholder
        marg_mask = torch.isnan(data_q)
        has_marginalizations = marg_mask.any()
        if has_marginalizations:
            data_q = data_q.clone()
            data_q[marg_mask] = self._supported_value

        # Unsqueeze to add channel dimension
        data_q = rearrange(data_q, "n f -> n 1 f")

        # Compute log probabilities
        dist = self.distribution()
        log_prob = dist.log_prob(data_q)

        # Marginalize entries
        if has_marginalizations:
            marg_mask_expanded = rearrange(marg_mask, "n f -> n 1 f 1 1")
            marg_mask_expanded = torch.broadcast_to(marg_mask_expanded, log_prob.shape)
            log_prob[marg_mask_expanded] = 0.0

        return log_prob

    def _sample(
        self,
        data: Tensor,
        sampling_ctx: SamplingContext,
        cache: Cache,
    ) -> Tensor:
        """Sample from the piecewise linear distribution into NaN entries of ``data``.

        Args:
            data: Evidence tensor of shape (N, D); NaN entries in this leaf's scope
                columns that are selected by the sampling context are filled in place.
            sampling_ctx: Sampling context (channel/repetition indices, masks, MPE flag).
            cache: Cache object (unused by this leaf).

        Returns:
            ``data`` with sampled (or MPE) values written into the selected entries.

        Raises:
            UnsupportedOperationError: If leaf params or differentiable routing are requested.
            ValueError: If the leaf has not been initialized.
            ShapeError: If the channel index has an incompatible feature width.
        """
        if sampling_ctx.return_leaf_params:
            raise UnsupportedOperationError(
                "PiecewiseLinear.sample() does not support return_leaf_params=True yet."
            )
        if sampling_ctx.is_differentiable:
            raise UnsupportedOperationError(
                "PiecewiseLinear.sample() does not support differentiable routing yet."
            )
        if not self.is_initialized:
            raise ValueError(
                "PiecewiseLinear leaf has not been initialized. " "Call initialize(data, domains) first."
            )

        # Prepare data tensor
        sampling_ctx.validate_sampling_context(
            num_samples=data.shape[0],
            num_features=self.out_shape.features,
            num_channels=self.out_shape.channels,
            num_repetitions=self.out_shape.repetitions,
            allowed_feature_widths=(1, self.out_shape.features, data.shape[1]),
        )
        scope_cols = self._resolve_scope_columns(num_features=data.shape[1])
        out_of_scope = [col for col in range(data.shape[1]) if col not in scope_cols]
        marg_mask = torch.isnan(data)
        marg_mask[:, out_of_scope] = False

        # Mask that tells us which feature at which sample is relevant
        samples_mask = marg_mask
        ctx_channel_index, ctx_mask = self._slice_sampling_context(
            sampling_ctx=sampling_ctx,
            num_features=data.shape[1],
            scope_cols=scope_cols,
        )
        samples_mask[:, scope_cols] &= ctx_mask

        # Count number of samples to draw
        instance_mask = samples_mask.sum(1) > 0
        n_samples = instance_mask.sum()

        dist = self.distribution(with_differentiable_sampling=sampling_ctx.is_differentiable)

        n_samples_int = int(n_samples.item())
        if sampling_ctx.is_mpe:
            samples = rearrange(dist.mode, "c f l r -> 1 c f l r")
            samples = repeat(samples, "1 c f l r -> n c f l r", n=n_samples_int).detach()
        else:
            samples = dist.sample((n_samples_int,))

        # Handle repetition index
        if samples.ndim == 5:
            repetition_index = sampling_ctx.repetition_index[instance_mask]
            num_channels = samples.shape[1]
            num_features = samples.shape[2]
            num_leaves = samples.shape[3]
            r_idxs = repeat(
                rearrange(repetition_index, "n -> n 1 1 1 1"),
                "n 1 1 1 1 -> n c f l 1",
                c=num_channels,
                f=num_features,
                l=num_leaves,
            )
            samples = rearrange(torch.gather(samples, dim=-1, index=r_idxs), "n c f l 1 -> n c f l")

        # Handle channel index - gather on leaves dimension (dim=3)
        # samples shape after repetition handling: (N, C=1, F, L)
        if self.out_shape.channels == 1:
            # With a single leaf, every routed channel collapses to index 0. The local
            # index used for the gather below must be zeroed too, otherwise a nonzero
            # routed index would address past the size-1 leaf axis.
            sampling_ctx.channel_index = torch.zeros_like(sampling_ctx.channel_index)
            ctx_channel_index = torch.zeros_like(ctx_channel_index)

        # c_idxs needs shape (N, 1, F, 1) to gather on dim=3
        c_idxs = ctx_channel_index[instance_mask]
        num_features = samples.shape[2]
        if c_idxs.dim() == 1:
            c_idxs = c_idxs.unsqueeze(1)
        if c_idxs.shape[1] == 1 and num_features > 1:
            c_idxs = c_idxs.expand(-1, num_features)
        elif c_idxs.shape[1] != num_features:
            raise ShapeError(
                "sampling_ctx.channel_index has incompatible feature width for PiecewiseLinear.sample: "
                f"got {c_idxs.shape[1]}, expected 1 or {num_features}."
            )
        c_idxs = rearrange(c_idxs.to(torch.long), "n f -> n 1 f 1")
        samples = samples.gather(dim=3, index=c_idxs).squeeze(3)  # (N, 1, F)

        # Squeeze channel dimension
        samples = rearrange(samples, "n 1 f -> n f")

        # Update data with samples
        row_indices = instance_mask.nonzero(as_tuple=True)[0]
        scope_idx = torch.tensor(scope_cols, dtype=torch.long, device=data.device)
        num_scope_features = len(scope_idx)
        rows = repeat(row_indices, "n -> n s", s=num_scope_features)
        cols = repeat(scope_idx, "s -> n s", n=n_samples_int)
        mask_subset = samples_mask[instance_mask][:, scope_cols]
        data[rows[mask_subset], cols[mask_subset]] = samples[mask_subset].to(data.dtype)
        return data