"""Piecewise linear leaf distribution module.

This module provides a non-parametric density estimation approach that
approximates data distributions using piecewise linear functions constructed
from histograms. It uses K-means clustering to create multiple distributions
per leaf.
"""

from __future__ import annotations

import itertools
import logging
from typing import List, Optional

import torch
from torch import Tensor, nn

from spflow.exceptions import OptionalDependencyError
from spflow.meta.data.scope import Scope
from spflow.modules.leaves.leaf import LeafModule
from spflow.utils.cache import Cache
from spflow.utils.domain import DataType, Domain
from spflow.utils.histogram import get_bin_edges_torch
from spflow.utils.sampling_context import SamplingContext, init_default_sampling_context

logger = logging.getLogger(__name__)


def pairwise(iterable):
    """Iterate over consecutive pairs.

    s -> (s0, s1), (s1, s2), (s2, s3), ...
    """
    a, b = itertools.tee(iterable)
    next(b, None)
    return zip(a, b)
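

# Doctest-style sketch (illustrative only, not from the original source):
#
#   >>> list(pairwise([1, 2, 3, 4]))
#   [(1, 2), (2, 3), (3, 4)]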


def interp(x: Tensor, xp: Tensor, fp: Tensor, dim: int = -1, extrapolate: str = "constant") -> Tensor:
    """One-dimensional linear interpolation between monotonically increasing sample points.

    Returns the one-dimensional piecewise linear interpolant to a function with
    given discrete data points (xp, fp), evaluated at x.

    Source: https://github.com/pytorch/pytorch/issues/50334#issuecomment-2304751532

    Args:
        x: The x-coordinates at which to evaluate the interpolated values.
        xp: The x-coordinates of the data points, must be increasing.
        fp: The y-coordinates of the data points, same shape as xp.
        dim: Dimension across which to interpolate.
        extrapolate: How to handle values outside the range of xp. Options:
            - 'linear': Extrapolate linearly beyond range.
            - 'constant': Use boundary value of fp for x outside xp.

    Returns:
        The interpolated values, same size as x.
    """
    # Move the interpolation dimension to the last axis
    x = x.movedim(dim, -1)
    xp = xp.movedim(dim, -1)
    fp = fp.movedim(dim, -1)

    m = torch.diff(fp) / torch.diff(xp)  # slope
    b = fp[..., :-1] - m * xp[..., :-1]  # offset

    # Ensure contiguous inputs for searchsorted
    xp = xp.contiguous()
    x = x.contiguous()
    indices = torch.searchsorted(xp, x, right=False)

    if extrapolate == "constant":
        # Pad m and b to get constant values outside of xp range
        m = torch.cat([torch.zeros_like(m)[..., :1], m, torch.zeros_like(m)[..., :1]], dim=-1)
        b = torch.cat([fp[..., :1], b, fp[..., -1:]], dim=-1)
    else:  # extrapolate == 'linear'
        indices = torch.clamp(indices - 1, 0, m.shape[-1] - 1)

    values = m.gather(-1, indices) * x + b.gather(-1, indices)

    # Clamp to non-negative values: linear extrapolation can dip below zero,
    # but the interpolant is used as a density here.
    values = values.clamp(min=0.0)

    return values.movedim(-1, dim)
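

# Illustrative usage sketch (values assumed, not part of the original source):
# interpolating a triangular density with breaks at [0, 1, 2]. With the default
# 'constant' extrapolation, points outside the support pick up the boundary
# density of 0.
#
#   >>> xp = torch.tensor([0.0, 1.0, 2.0])
#   >>> fp = torch.tensor([0.0, 1.0, 0.0])
#   >>> interp(torch.tensor([-1.0, 0.5, 1.5, 2.5]), xp, fp)
#   tensor([0.0000, 0.5000, 0.5000, 0.0000])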


class PiecewiseLinearDist:
    """Custom distribution for piecewise linear density estimation.

    Mimics the torch.distributions interface with log_prob, sample, and mode methods.

    Attributes:
        xs: Nested list of x-coordinates [R][L][F][C] where R=repetitions, L=leaves,
            F=features, C=channels.
        ys: Nested list of y-coordinates (densities) with same structure as xs.
        domains: List of Domain objects, one per feature.
    """

    def __init__(self, xs: List, ys: List, domains: List[Domain]):
        """Initialize the piecewise linear distribution.

        Args:
            xs: Nested list of x-coordinates for piecewise linear functions.
            ys: Nested list of y-coordinates (densities) for piecewise linear functions.
            domains: List of Domain objects describing each feature's domain.
        """
        self.xs = xs
        self.ys = ys
        self.domains = domains

        self.num_repetitions = len(xs)
        self.num_leaves = len(xs[0])
        self.num_features = len(xs[0][0])
        self.num_channels = len(xs[0][0][0])

    def _compute_cdf(self, xs: Tensor, ys: Tensor) -> Tensor:
        """Compute the CDF for the given piecewise linear function.

        Args:
            xs: X-coordinates of the piecewise function.
            ys: Y-coordinates (densities) of the piecewise function.

        Returns:
            CDF values at each x-coordinate.
        """
        # Compute the integral over each interval using the trapezoid rule
        intervals = torch.diff(xs)
        trapezoids = 0.5 * intervals * (ys[:-1] + ys[1:])  # Partial areas

        # Cumulative sum to build the CDF
        cdf = torch.cat([torch.zeros(1, device=xs.device), torch.cumsum(trapezoids, dim=0)])

        # Normalize the CDF to ensure it goes from 0 to 1
        cdf = cdf / (cdf[-1] + 1e-10)

        return cdf
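
    # Worked numeric sketch (assumed values, for illustration only): on the
    # triangle xs = [0, 1, 2], ys = [0, 1, 0], the trapezoid rule gives partial
    # areas [0.5, 0.5], so the CDF at the breaks (before the 1e-10-stabilized
    # normalization, which is a no-op here) is:
    #
    #   >>> xs = torch.tensor([0.0, 1.0, 2.0]); ys = torch.tensor([0.0, 1.0, 0.0])
    #   >>> t = 0.5 * torch.diff(xs) * (ys[:-1] + ys[1:])
    #   >>> torch.cat([torch.zeros(1), torch.cumsum(t, dim=0)])
    #   tensor([0.0000, 0.5000, 1.0000])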

    def sample(self, sample_shape: torch.Size | tuple[int, ...]) -> Tensor:
        """Sample from the piecewise linear distribution.

        Args:
            sample_shape: Shape of samples to generate.

        Returns:
            Samples tensor of shape (sample_shape[0], C, F, L, R).
        """
        num_samples = sample_shape[0]
        samples = torch.empty(
            (
                num_samples,
                self.num_channels,
                self.num_features,
                self.num_leaves,
                self.num_repetitions,
            ),
            device=self.xs[0][0][0][0].device,
        )

        for i_feature in range(self.num_features):
            for i_channel in range(self.num_channels):
                for i_repetition in range(self.num_repetitions):
                    for i_leaf in range(self.num_leaves):
                        xs_i = self.xs[i_repetition][i_leaf][i_feature][i_channel]
                        ys_i = self.ys[i_repetition][i_leaf][i_feature][i_channel]

                        if self.domains[i_feature].data_type == DataType.DISCRETE:
                            # Sample from a categorical distribution over the discrete values
                            ys_i_wo_tails = ys_i[1:-1]  # Cut off the tail breaks
                            dist = torch.distributions.Categorical(probs=ys_i_wo_tails)
                            # Map sampled category indices back to the actual domain
                            # values (xs_i without its tail breaks)
                            idx = dist.sample(sample_shape)
                            samples[:, i_channel, i_feature, i_leaf, i_repetition] = xs_i[1:-1][idx]
                        elif self.domains[i_feature].data_type == DataType.CONTINUOUS:
                            # Compute the CDF for this piecewise function
                            cdf = self._compute_cdf(xs_i, ys_i)

                            # Sample from a uniform distribution
                            u = torch.rand(num_samples, device=xs_i.device)

                            # Find the corresponding segment using searchsorted
                            # Ensure contiguous inputs
                            cdf = cdf.contiguous()
                            u = u.contiguous()
                            indices = torch.searchsorted(cdf, u, right=True)

                            # Clamp indices to be within valid range
                            indices = torch.clamp(indices, 1, len(xs_i) - 1)

                            # Perform linear interpolation to get the sample value
                            x0, x1 = xs_i[indices - 1], xs_i[indices]
                            cdf0, cdf1 = cdf[indices - 1], cdf[indices]
                            slope = (x1 - x0) / (cdf1 - cdf0 + 1e-8)  # Avoid division by zero

                            # Compute the sampled value
                            samples[:, i_channel, i_feature, i_leaf, i_repetition] = x0 + slope * (u - cdf0)
                        else:
                            raise ValueError(f"Unknown data type: {self.domains[i_feature].data_type}")

        return samples
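
    # The continuous branch above is inverse-transform sampling: draw u ~ U(0, 1),
    # find the CDF segment containing u, and invert the linear segment. A minimal
    # standalone sketch of the inversion step (assumed values):
    #
    #   >>> cdf = torch.tensor([0.0, 0.5, 1.0])  # as returned by _compute_cdf
    #   >>> xs = torch.tensor([0.0, 1.0, 2.0])
    #   >>> u = torch.tensor([0.25])
    #   >>> i = torch.searchsorted(cdf, u, right=True).clamp(1, len(xs) - 1)
    #   >>> xs[i - 1] + (xs[i] - xs[i - 1]) / (cdf[i] - cdf[i - 1] + 1e-8) * (u - cdf[i - 1])
    #   tensor([0.5000])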

    @property
    def mode(self) -> Tensor:
        """Compute the mode of the distribution.

        Returns:
            Modes tensor of shape (C, F, L, R).
        """
        modes = torch.empty(
            (
                self.num_channels,
                self.num_features,
                self.num_leaves,
                self.num_repetitions,
            ),
            device=self.xs[0][0][0][0].device,
        )

        for i_feature in range(self.num_features):
            for i_channel in range(self.num_channels):
                for i_repetition in range(self.num_repetitions):
                    for i_leaf in range(self.num_leaves):
                        xs_i = self.xs[i_repetition][i_leaf][i_feature][i_channel]
                        ys_i = self.ys[i_repetition][i_leaf][i_feature][i_channel]

                        # Find the mode (the x value with the highest PDF value)
                        max_idx = torch.argmax(ys_i)
                        mode_value = xs_i[max_idx]

                        # Store the mode value
                        modes[i_channel, i_feature, i_leaf, i_repetition] = mode_value

        return modes

    def log_prob(self, x: Tensor) -> Tensor:
        """Compute log probabilities for input data.

        Args:
            x: Input tensor of shape (N, C, F, 1, 1) or (N, C, F).

        Returns:
            Log probabilities of shape (N, C, F, L, R).
        """
        # Handle input shapes
        if x.dim() == 5:
            x = x.squeeze(-1).squeeze(-1)

        batch_size = x.shape[0]
        probs = torch.zeros(
            batch_size,
            self.num_channels,
            self.num_features,
            self.num_leaves,
            self.num_repetitions,
            device=x.device,
        )

        # Perform linear interpolation
        for i_feature in range(self.num_features):
            for i_channel in range(self.num_channels):
                for i_repetition in range(self.num_repetitions):
                    for i_leaf in range(self.num_leaves):
                        xs_i = self.xs[i_repetition][i_leaf][i_feature][i_channel]
                        ys_i = self.ys[i_repetition][i_leaf][i_feature][i_channel]
                        ivalues = interp(x[:, i_channel, i_feature], xs_i, ys_i)
                        probs[:, i_channel, i_feature, i_leaf, i_repetition] = ivalues

        # Return the logarithm of probabilities
        logprobs = torch.log(probs + 1e-10)
        logprobs = torch.clamp(logprobs, min=-300.0)
        return logprobs


class PiecewiseLinear(LeafModule):
    """Piecewise linear leaf distribution module.

    First constructs histograms from the data using K-means clustering, then
    approximates the histograms with piecewise linear functions.

    This leaf requires initialization with data via the `initialize()` method
    before it can be used for inference or sampling.

    Attributes:
        alpha: Laplace smoothing parameter.
        xs: Nested list of x-coordinates for piecewise linear functions.
        ys: Nested list of y-coordinates (densities) for piecewise linear functions.
        domains: List of Domain objects describing each feature.
        is_initialized: Whether the distribution has been initialized with data.
    """

    def __init__(
        self,
        scope: Scope | int | List[int],
        out_channels: int = 1,
        num_repetitions: int = 1,
        alpha: float = 0.0,
    ):
        """Initialize PiecewiseLinear leaf module.

        Args:
            scope: Variable scope (Scope, int, or list[int]).
            out_channels: Number of output channels (clusters via K-means).
            num_repetitions: Number of repetitions.
            alpha: Laplace smoothing parameter (default 0.0).
        """
        super().__init__(
            scope=scope,
            out_channels=out_channels,
            num_repetitions=num_repetitions,
        )
        if alpha < 0:
            raise ValueError(f"alpha must be non-negative, got {alpha}")
        self.alpha = alpha

        # These will be set during initialization
        self.xs: Optional[List] = None
        self.ys: Optional[List] = None
        self.domains: Optional[List[Domain]] = None
        self.is_initialized = False

        # Register a dummy buffer so device detection works
        self.register_buffer("_device_buffer", torch.zeros(1))

    @property
    def _torch_distribution_class(self):
        """PiecewiseLinear uses a custom distribution, not a torch.distributions class."""
        return None

    @property
    def _supported_value(self) -> float:
        """Returns a value in the support of the distribution."""
        return 0.0

    @property
    def distribution(self) -> PiecewiseLinearDist:
        """Returns the underlying PiecewiseLinearDist object.

        Raises:
            ValueError: If the distribution has not been initialized.
        """
        if not self.is_initialized:
            raise ValueError(
                "PiecewiseLinear leaf has not been initialized. Call initialize(data, domains) first."
            )
        return PiecewiseLinearDist(self.xs, self.ys, self.domains)  # type: ignore[arg-type]

    @property
    def mode(self) -> Tensor:
        """Return distribution mode.

        Returns:
            Mode of the distribution.
        """
        return self.distribution.mode

    def params(self) -> dict:
        """Returns the parameters of the distribution.

        For PiecewiseLinear, returns the xs and ys nested lists.
        """
        return {"xs": self.xs, "ys": self.ys}

    def _compute_parameter_estimates(self, data: Tensor, weights: Tensor, bias_correction: bool) -> dict:
        """Not implemented for PiecewiseLinear - use initialize() instead."""
        raise NotImplementedError("PiecewiseLinear does not support MLE. Use initialize() instead.")

    def initialize(self, data: Tensor, domains: List[Domain]) -> None:
        """Initialize the piecewise linear distribution with data.

        Uses K-means clustering to create multiple distributions per leaf, then
        constructs histograms and approximates them with piecewise linear functions.

        Args:
            data: Training data tensor of shape (N, F) where N is the batch size
                and F is the number of features.
            domains: List of Domain objects, one per feature.

        Raises:
            ValueError: If the data shape doesn't match the scope.
        """
        try:
            from fast_pytorch_kmeans import KMeans
        except ImportError as e:
            raise OptionalDependencyError(
                "fast_pytorch_kmeans is required for PiecewiseLinear. "
                "Install with: pip install fast-pytorch-kmeans"
            ) from e

        logger.info(f"Initializing PiecewiseLinear with data shape {data.shape}")

        # Validate input
        num_features = len(self.scope.query)
        if data.shape[1] != num_features:
            raise ValueError(f"Data has {data.shape[1]} features but scope has {num_features}")
        if len(domains) != num_features:
            raise ValueError(f"Got {len(domains)} domains but scope has {num_features} features")

        self.domains = domains
        device = data.device

        # Parameters stored as nested lists [R][L][F][C]
        xs = []
        ys = []

        num_leaves = self.out_shape.channels

        for i_repetition in range(self.out_shape.repetitions):
            xs_leaves = []
            ys_leaves = []

            # Cluster data into num_leaves clusters
            if num_leaves > 1:
                kmeans = KMeans(n_clusters=num_leaves, mode="euclidean", verbose=0, init_method="random")
                kmeans.fit(data.float())
                cluster_idxs = kmeans.max_sim(a=data.float(), b=kmeans.centroids)[1]
            else:
                cluster_idxs = torch.zeros(data.shape[0], dtype=torch.long, device=device)

            for cluster_idx in range(num_leaves):
                # Select data for this cluster
                mask = cluster_idxs == cluster_idx
                cluster_data = data[mask]

                xs_features = []
                ys_features = []

                for i_feature in range(num_features):
                    xs_channels = []
                    ys_channels = []

                    # For PiecewiseLinear, we use a single "channel" per feature
                    # (the reference used num_channels but SPFlow uses out_channels for leaves)
                    data_subset = cluster_data[:, i_feature].float()

                    if self.domains[i_feature].data_type == DataType.DISCRETE:
                        # Edges are the discrete values
                        mids = torch.as_tensor(
                            self.domains[i_feature].values, device=device, dtype=torch.float32
                        )
                        # Add a break at the end
                        breaks = torch.cat([mids, mids[-1:].add(1)])

                        if data_subset.shape[0] == 0:
                            # If there is no data in this cluster, fall back to uniform densities
                            densities = torch.ones(len(mids), device=device) / len(mids)
                        else:
                            # Compute histogram densities
                            densities = torch.histogram(
                                data_subset.cpu(), bins=breaks.cpu(), density=True
                            ).hist.to(device)
                    elif self.domains[i_feature].data_type == DataType.CONTINUOUS:
                        # Find histogram bins using automatic bin width selection
                        if data_subset.numel() > 0:
                            bins, _ = get_bin_edges_torch(data_subset)
                        else:
                            # Fallback for empty data
                            bins = torch.linspace(
                                self.domains[i_feature].min or 0,
                                self.domains[i_feature].max or 1,
                                11,
                                device=device,
                            )

                        # Construct histogram
                        if data_subset.numel() > 0:
                            densities = torch.histogram(
                                data_subset.cpu(), bins=bins.cpu(), density=True
                            ).hist.to(device)
                        else:
                            densities = torch.ones(len(bins) - 1, device=device) / (len(bins) - 1)

                        breaks = bins
                        mids = ((breaks + torch.roll(breaks, shifts=-1, dims=0)) / 2)[:-1]
                    else:
                        raise ValueError(f"Unknown data type: {domains[i_feature].data_type}")

                    # Apply optional Laplace smoothing
                    if self.alpha > 0:
                        n_samples = data_subset.shape[0]
                        n_bins = len(breaks) - 1
                        counts = densities * n_samples
                        densities = (counts + self.alpha) / (n_samples + n_bins * self.alpha)

                    # Add tail breaks to start and end
                    if self.domains[i_feature].data_type == DataType.DISCRETE:
                        tail_width = 1
                        x = [b.item() for b in breaks[:-1]]
                        x = [x[0] - tail_width] + x + [x[-1] + tail_width]
                    elif self.domains[i_feature].data_type == DataType.CONTINUOUS:
                        EPS = 1e-8
                        x = (
                            [breaks[0].item() - EPS]
                            + [b0.item() + (b1.item() - b0.item()) / 2 for (b0, b1) in pairwise(breaks)]
                            + [breaks[-1].item() + EPS]
                        )
                    else:
                        raise ValueError(
                            f"Unknown data type in tail break construction: {self.domains[i_feature].data_type}"
                        )

                    # Add density 0 at the start and end tail breaks
                    y = [0.0] + [d.item() for d in densities] + [0.0]

                    # Construct tensors
                    x = torch.tensor(x, device=device, dtype=torch.float32)
                    y = torch.tensor(y, device=device, dtype=torch.float32)

                    # Compute the AUC using the trapezoid rule
                    auc = torch.trapezoid(y=y, x=x)

                    # Normalize y by the AUC so the density integrates to 1
                    if auc > 0:
                        y = y / auc

                    xs_channels.append(x)
                    ys_channels.append(y)

                    xs_features.append(xs_channels)
                    ys_features.append(ys_channels)

                xs_leaves.append(xs_features)
                ys_leaves.append(ys_features)

            xs.append(xs_leaves)
            ys.append(ys_leaves)

        self.xs = xs
        self.ys = ys
        self.is_initialized = True

        logger.info("PiecewiseLinear initialization complete")
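
    # Layout sketch (illustrative, `leaf` is a hypothetical initialized instance):
    # after initialize() on data of shape (N, F) with out_channels=L and
    # num_repetitions=R, self.xs is indexed as xs[r][l][f][c] with c always 0
    # (a single channel per feature); each xs[r][l][f][0] is a 1-D tensor of
    # break positions, and ys[r][l][f][0] holds the matching densities, zero at
    # both tail breaks and normalized to unit area.
    #
    #   >>> leaf.xs[0][0][0][0]  # break positions: repetition 0, leaf 0, feature 0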

    def reset(self) -> None:
        """Reset the distribution to an uninitialized state."""
        self.is_initialized = False
        self.xs = None
        self.ys = None
        self.domains = None

    def log_likelihood(
        self,
        data: Tensor,
        cache: Cache | None = None,
    ) -> Tensor:
        """Compute log-likelihoods for input data.

        Args:
            data: Input data tensor of shape (N, F).
            cache: Optional cache dictionary.

        Returns:
            Log-likelihood tensor.
        """
        if not self.is_initialized:
            raise ValueError(
                "PiecewiseLinear leaf has not been initialized. Call initialize(data, domains) first."
            )

        if data.dim() != 2:
            raise ValueError(f"Data must be 2-dimensional (batch, num_features), got shape {data.shape}.")

        # Get scope-relevant data
        data_q = data[:, self.scope.query]

        # Handle marginalization
        marg_mask = torch.isnan(data_q)
        has_marginalizations = marg_mask.any()
        if has_marginalizations:
            data_q = data_q.clone()
            data_q[marg_mask] = self._supported_value

        # Unsqueeze to add the channel dimension
        data_q = data_q.unsqueeze(1)  # [N, 1, F]

        # Compute log probabilities
        dist = self.distribution
        log_prob = dist.log_prob(data_q)

        # Marginalize entries
        if has_marginalizations:
            # Expand the mask to match the log_prob shape
            marg_mask_expanded = marg_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
            marg_mask_expanded = torch.broadcast_to(marg_mask_expanded, log_prob.shape)
            log_prob[marg_mask_expanded] = 0.0

        return log_prob

    def sample(
        self,
        num_samples: int | None = None,
        data: Tensor | None = None,
        is_mpe: bool = False,
        cache: Cache | None = None,
        sampling_ctx: Optional[SamplingContext] = None,
    ) -> Tensor:
        """Sample from the piecewise linear distribution.

        Args:
            num_samples: Number of samples to generate.
            data: Optional evidence tensor.
            is_mpe: Perform MPE (mode) instead of sampling.
            cache: Optional cache dictionary.
            sampling_ctx: Optional sampling context.

        Returns:
            Sampled data tensor.
        """
        if not self.is_initialized:
            raise ValueError(
                "PiecewiseLinear leaf has not been initialized. Call initialize(data, domains) first."
            )

        # Prepare data tensor
        data = self._prepare_sample_data(num_samples, data)
        sampling_ctx = init_default_sampling_context(sampling_ctx, data.shape[0])

        out_of_scope = list(filter(lambda x: x not in self.scope.query, range(data.shape[1])))
        marg_mask = torch.isnan(data)
        marg_mask[:, out_of_scope] = False

        # Mask that tells us which feature at which sample is relevant
        samples_mask = marg_mask
        samples_mask[:, self.scope.query] &= sampling_ctx.mask

        # Count the number of samples to draw
        instance_mask = samples_mask.sum(1) > 0
        n_samples = instance_mask.sum()

        if sampling_ctx.repetition_idx is None:
            if self.out_shape.repetitions > 1:
                raise ValueError(
                    "Repetition index must be provided in sampling context for leaves with multiple repetitions."
                )
            else:
                sampling_ctx.repetition_idx = torch.zeros(data.shape[0], dtype=torch.long, device=data.device)

        dist = self.distribution
        n_samples_int = int(n_samples.item())
        if is_mpe:
            samples = dist.mode.unsqueeze(0)
            samples = samples.repeat(n_samples_int, 1, 1, 1, 1).detach()
        else:
            samples = dist.sample((n_samples_int,))

        # Handle repetition index
        if samples.ndim == 5:
            repetition_idx = sampling_ctx.repetition_idx[instance_mask]
            r_idxs = repetition_idx.view(-1, 1, 1, 1, 1).expand(
                -1, samples.shape[1], samples.shape[2], samples.shape[3], -1
            )
            samples = torch.gather(samples, dim=-1, index=r_idxs).squeeze(-1)

        # Handle channel index - gather on the leaves dimension (dim=3)
        # samples shape after repetition handling: (N, C=1, F, L)
        if self.out_shape.channels == 1:
            sampling_ctx.channel_index.zero_()

        # c_idxs needs shape (N, 1, F, 1) to gather on dim=3
        c_idxs = sampling_ctx.channel_index[instance_mask]  # (N,)
        c_idxs = c_idxs.view(-1, 1, 1, 1).expand(-1, 1, samples.shape[2], 1)  # (N, 1, F, 1)
        samples = samples.gather(dim=3, index=c_idxs).squeeze(3)  # (N, 1, F)

        # Squeeze the channel dimension
        samples = samples.squeeze(1)  # (N, F)

        # Update data with samples
        row_indices = instance_mask.nonzero(as_tuple=True)[0]
        scope_idx = torch.tensor(self.scope.query, dtype=torch.long, device=data.device)
        rows = row_indices.unsqueeze(1).expand(-1, len(scope_idx))
        cols = scope_idx.unsqueeze(0).expand(n_samples_int, -1)
        mask_subset = samples_mask[instance_mask][:, self.scope.query]
        data[rows[mask_subset], cols[mask_subset]] = samples[mask_subset].to(data.dtype)

        return data
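

# End-to-end usage sketch (hedged: the exact Domain and Scope constructor
# signatures are assumed from their use above, not verified against the
# SPFlow API):
#
#   >>> data = torch.randn(1000, 2)
#   >>> domains = [Domain(data_type=DataType.CONTINUOUS) for _ in range(2)]
#   >>> leaf = PiecewiseLinear(scope=Scope([0, 1]), out_channels=2)
#   >>> leaf.initialize(data, domains)
#   >>> ll = leaf.log_likelihood(data)  # (N, C, F, L, R) = (1000, 1, 2, 2, 1)
#   >>> samples = leaf.sample(num_samples=10)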