# Source code for spflow.measures.information_theory

from __future__ import annotations

from collections.abc import Iterable

import torch
from torch import Tensor

from spflow.exceptions import InvalidParameterError, UnsupportedOperationError
from spflow.measures._utils import (
    as_scope,
    fork_rng,
    infer_discrete_domains,
    reduce_log_likelihood,
)
from spflow.meta.data.scope import Scope
from spflow.modules.module import Module


def entropy(
    model: Module,
    scope: Scope | int | Iterable[int],
    *,
    method: str = "mc",
    num_samples: int = 10_000,
    seed: int | None = None,
    channel_agg: str = "logmeanexp",
    repetition_agg: str = "logmeanexp",
) -> Tensor:
    """Estimate the entropy H(X) (in nats) for a subset of variables.

    The returned value is in nats (natural logarithm base), consistent with
    SPFlow log-likelihood conventions.

    Args:
        model: SPFlow probabilistic circuit.
        scope: Variables X to compute entropy for.
        method: "mc" (Monte Carlo) or "exact" (enumeration for tiny discrete domains).
        num_samples: Number of samples for Monte Carlo estimation.
        seed: Optional seed for best-effort deterministic sampling.
        channel_agg: How to aggregate multiple channels ("logmeanexp", "logsumexp", "first").
        repetition_agg: How to aggregate multiple repetitions ("logmeanexp", "logsumexp", "first").

    Returns:
        Scalar tensor containing H(X) in nats.

    Raises:
        InvalidParameterError: If ``scope`` is empty, ``method`` is unknown, or
            ``num_samples`` < 1 for Monte Carlo estimation.
    """
    scope = as_scope(scope)
    if scope.empty():
        raise InvalidParameterError("entropy scope must be non-empty.")
    if method not in ("mc", "exact"):
        # NOTE: message reconstructed onto one line; the extracted source had a
        # literal newline embedded in this f-string, inconsistent with the
        # identical message raised by mutual_information().
        raise InvalidParameterError(f"Unknown method '{method}'. Use 'mc' or 'exact'.")

    if method == "exact":
        # Exact entropy by enumerating the full discrete domain of the scope.
        domains = infer_discrete_domains(model, scope)
        rvs = list(scope.query)
        values = [domains[rv] for rv in rvs]
        grid = torch.cartesian_prod(*values)  # (N, |rvs|)
        if grid.dim() == 1:
            # cartesian_prod over a single RV returns shape (N,); normalize to (N, 1).
            grid = grid.unsqueeze(1)
        d = len(model.scope.query)
        # NaN marks marginalized-out variables for the circuit's log_likelihood.
        data = torch.full(
            (grid.shape[0], d),
            torch.nan,
            device=model.device,
            dtype=torch.get_default_dtype(),
        )
        for j, rv in enumerate(rvs):
            data[:, rv] = grid[:, j]
        ll = reduce_log_likelihood(
            model.log_likelihood(data),
            channel_agg=channel_agg,
            repetition_agg=repetition_agg,
        )  # (N,)
        # Zero-probability states contribute 0 to sum(p * log p); skip non-finite
        # log-likelihoods (-inf) to avoid NaN from exp(-inf) * -inf.
        mask = torch.isfinite(ll)
        p_log_p = torch.zeros_like(ll)
        p_log_p[mask] = torch.exp(ll[mask]) * ll[mask]
        return -p_log_p.sum()

    # Monte Carlo: H(X) ≈ -E_{x ~ p}[log p(x)] using model samples.
    if num_samples < 1:
        raise InvalidParameterError("num_samples must be >= 1 for Monte Carlo entropy.")
    with fork_rng(seed, model.device):
        if seed is not None:
            torch.manual_seed(seed)
        samples = model.sample(num_samples=num_samples)
        d = samples.shape[1]
        query = list(scope.query)
        # Keep only the scope's variables as evidence; NaN marginalizes the rest.
        evidence = torch.full(
            (num_samples, d), torch.nan, device=samples.device, dtype=samples.dtype
        )
        evidence[:, query] = samples[:, query]
        ll = reduce_log_likelihood(
            model.log_likelihood(evidence),
            channel_agg=channel_agg,
            repetition_agg=repetition_agg,
        )
    return -ll.mean()
def mutual_information(
    model: Module,
    x_scope: Scope | int | Iterable[int],
    y_scope: Scope | int | Iterable[int],
    *,
    method: str = "mc",
    num_samples: int = 10_000,
    seed: int | None = None,
    channel_agg: str = "logmeanexp",
    repetition_agg: str = "logmeanexp",
) -> Tensor:
    """Estimate mutual information I(X;Y) (in nats).

    With ``method="exact"`` the result is computed via exact entropies as
    H(X) + H(Y) - H(X,Y); with ``method="mc"`` it is the Monte Carlo estimate
    E[log p(x,y) - log p(x) - log p(y)] over samples drawn from the model.
    """
    x_scope = as_scope(x_scope)
    y_scope = as_scope(y_scope)
    if set(x_scope.query).intersection(y_scope.query):
        raise InvalidParameterError("x_scope and y_scope must be disjoint for mutual_information.")

    if method == "exact":
        # I(X;Y) = H(X) + H(Y) - H(X,Y), each term enumerated exactly.
        shared = dict(method="exact", channel_agg=channel_agg, repetition_agg=repetition_agg)
        h_x = entropy(model, x_scope, **shared)
        h_y = entropy(model, y_scope, **shared)
        h_xy = entropy(model, Scope(list(x_scope.query) + list(y_scope.query)), **shared)
        return h_x + h_y - h_xy

    if method != "mc":
        raise InvalidParameterError(f"Unknown method '{method}'. Use 'mc' or 'exact'.")
    if num_samples < 1:
        raise InvalidParameterError("num_samples must be >= 1 for Monte Carlo mutual_information.")

    with fork_rng(seed, model.device) as _:
        if seed is not None:
            torch.manual_seed(seed)
        draws = model.sample(num_samples=num_samples)
        num_features = draws.shape[1]
        x_vars = list(x_scope.query)
        y_vars = list(y_scope.query)

        def marginal_ll(vars_: list[int]) -> Tensor:
            # Log-likelihood of the marginal over ``vars_``: variables outside
            # it stay NaN, which the circuit treats as marginalized evidence.
            partial = torch.full(
                (num_samples, num_features),
                torch.nan,
                device=draws.device,
                dtype=draws.dtype,
            )
            partial[:, vars_] = draws[:, vars_]
            return reduce_log_likelihood(
                model.log_likelihood(partial),
                channel_agg=channel_agg,
                repetition_agg=repetition_agg,
            )

        ll_joint = marginal_ll(x_vars + y_vars)
        ll_x = marginal_ll(x_vars)
        ll_y = marginal_ll(y_vars)
        return (ll_joint - ll_x - ll_y).mean()
def conditional_mutual_information(
    model: Module,
    x_scope: Scope | int | Iterable[int],
    y_scope: Scope | int | Iterable[int],
    z_scope: Scope | int | Iterable[int],
    *,
    method: str = "mc",
    num_samples: int = 10_000,
    seed: int | None = None,
    channel_agg: str = "logmeanexp",
    repetition_agg: str = "logmeanexp",
) -> Tensor:
    """Estimate conditional mutual information I(X;Y|Z) (in nats).

    With ``method="exact"`` the result is H(X,Z) + H(Y,Z) - H(Z) - H(X,Y,Z),
    each entropy enumerated exactly; with ``method="mc"`` it is the Monte Carlo
    estimate E[log p(x,y,z) + log p(z) - log p(x,z) - log p(y,z)].
    """
    x_scope = as_scope(x_scope)
    y_scope = as_scope(y_scope)
    z_scope = as_scope(z_scope)
    all_rvs = list(x_scope.query) + list(y_scope.query) + list(z_scope.query)
    if len(set(all_rvs)) != len(all_rvs):
        raise InvalidParameterError("x_scope, y_scope, and z_scope must be pairwise disjoint.")

    if method == "exact":
        shared = dict(method="exact", channel_agg=channel_agg, repetition_agg=repetition_agg)
        h_z = entropy(model, z_scope, **shared)
        h_xz = entropy(model, Scope(list(x_scope.query) + list(z_scope.query)), **shared)
        h_yz = entropy(model, Scope(list(y_scope.query) + list(z_scope.query)), **shared)
        h_xyz = entropy(
            model,
            Scope(list(x_scope.query) + list(y_scope.query) + list(z_scope.query)),
            **shared,
        )
        return h_xz + h_yz - h_z - h_xyz

    if method != "mc":
        raise InvalidParameterError(f"Unknown method '{method}'. Use 'mc' or 'exact'.")
    if num_samples < 1:
        raise InvalidParameterError(
            "num_samples must be >= 1 for Monte Carlo conditional_mutual_information."
        )

    with fork_rng(seed, model.device) as _:
        if seed is not None:
            torch.manual_seed(seed)
        draws = model.sample(num_samples=num_samples)
        num_features = draws.shape[1]
        x_vars = list(x_scope.query)
        y_vars = list(y_scope.query)
        z_vars = list(z_scope.query)

        def marginal_ll(vars_: list[int]) -> Tensor:
            # Log-likelihood of the marginal over ``vars_``; NaN entries are
            # marginalized out by the circuit's log_likelihood.
            partial = torch.full(
                (num_samples, num_features),
                torch.nan,
                device=draws.device,
                dtype=draws.dtype,
            )
            partial[:, vars_] = draws[:, vars_]
            return reduce_log_likelihood(
                model.log_likelihood(partial),
                channel_agg=channel_agg,
                repetition_agg=repetition_agg,
            )

        ll_xyz = marginal_ll(x_vars + y_vars + z_vars)
        ll_z = marginal_ll(z_vars)
        ll_xz = marginal_ll(x_vars + z_vars)
        ll_yz = marginal_ll(y_vars + z_vars)
        # I(X;Y|Z) = E[log p(x,y,z) + log p(z) - log p(x,z) - log p(y,z)]
        return (ll_xyz + ll_z - ll_xz - ll_yz).mean()
# Public API of this module (order preserved from the original definition).
__all__ = [
    "entropy",
    "mutual_information",
    "conditional_mutual_information",
]