"""scikit-learn compatible wrappers for SPFlow models.
These wrappers are optional: SPFlow can be used without scikit-learn installed.
Importing this module does not require scikit-learn, but instantiating the
estimators will.
"""
from __future__ import annotations
from dataclasses import dataclass
import math
from typing import Any, Literal
import numpy as np
import torch
from einops import rearrange, reduce
from spflow.exceptions import (
InvalidParameterError,
InvalidTypeError,
OptionalDependencyError,
UnsupportedOperationError,
)
from spflow.learn.learn_spn import learn_spn
from spflow.learn.prometheus import learn_prometheus
from spflow.meta.data.scope import Scope
from spflow.modules.leaves.normal import Normal
from spflow.modules.module import Module
# scikit-learn is an optional dependency: fall back to plain `object` base
# classes so this module stays importable without sklearn installed.
try:  # pragma: no cover (covered via importorskip tests)
    from sklearn.base import BaseEstimator, ClassifierMixin, DensityMixin
    from sklearn.utils.validation import check_is_fitted

    _SKLEARN_AVAILABLE = True
except ModuleNotFoundError:  # pragma: no cover
    # Placeholder bases keep the class definitions below syntactically valid;
    # instantiating the estimators still fails fast via `_require_sklearn()`.
    BaseEstimator = object  # type: ignore[assignment]
    ClassifierMixin = object  # type: ignore[assignment]
    DensityMixin = object  # type: ignore[assignment]
    check_is_fitted = None  # type: ignore[assignment]
    _SKLEARN_AVAILABLE = False
def _require_sklearn() -> None:
    """Raise ``OptionalDependencyError`` when scikit-learn is not installed."""
    if _SKLEARN_AVAILABLE:
        return
    raise OptionalDependencyError(
        "scikit-learn is required for SPFlow sklearn integration. "
        "Install it with `pip install scikit-learn` (or `pip install spflow[sklearn]`)."
    )
def _as_2d_numpy(x: Any) -> np.ndarray:
"""Convert array-like to a 2D numpy array."""
arr = np.asarray(x)
if arr.ndim == 1:
return arr.reshape(-1, 1)
if arr.ndim != 2:
raise InvalidParameterError(f"Expected 2D array-like input, got shape {arr.shape}.")
return arr
def _torch_dtype_from_str(dtype: str | None) -> torch.dtype | None:
if dtype is None:
return None
if dtype == "float32":
return torch.float32
if dtype == "float64":
return torch.float64
raise InvalidParameterError(f"Unknown dtype '{dtype}'. Use 'float32', 'float64', or None.")
def _default_torch_device() -> torch.device:
"""Return the active default torch device, falling back to CPU when unavailable."""
get_default_device = getattr(torch, "get_default_device", None)
if get_default_device is None:
return torch.device("cpu")
return torch.device(get_default_device())
def _reduce_log_likelihood(
ll: torch.Tensor,
*,
channel_agg: Literal["logmeanexp", "logsumexp", "first"],
repetition_agg: Literal["logmeanexp", "logsumexp", "first"],
) -> torch.Tensor:
"""Reduce SPFlow log-likelihood tensor to per-sample scalar log-likelihoods.
SPFlow modules typically return log-likelihoods shaped like:
(batch, features, channels, repetitions)
This function:
- sums over features (log-space product),
- aggregates repetitions and channels (mixture-like reduction), and
- returns a 1D tensor of shape (batch,).
"""
if ll.dim() == 2:
ll = rearrange(ll, "b f -> b f 1 1")
elif ll.dim() == 3:
ll = rearrange(ll, "b f c -> b f c 1")
elif ll.dim() != 4:
raise InvalidParameterError(f"Unexpected log-likelihood shape {tuple(ll.shape)}.")
if ll.shape[0] == 0:
return ll.new_zeros((0,))
ll = reduce(ll, "b f c r -> b c r", "sum")
def reduce_over(t: torch.Tensor, dim: int, method: str) -> torch.Tensor:
if t.shape[dim] == 1:
return t.squeeze(dim)
if method == "first":
return t.select(dim, 0)
if method == "logsumexp":
return torch.logsumexp(t, dim=dim)
if method == "logmeanexp":
return torch.logsumexp(t, dim=dim) - math.log(t.shape[dim])
raise InvalidParameterError(f"Unknown reduction method '{method}'.")
ll = reduce_over(ll, dim=-1, method=repetition_agg) # (B, C)
ll = reduce_over(ll, dim=-1, method=channel_agg) # (B,)
return ll
@dataclass(frozen=True)
class _StructureLearnerSpec:
    """Validated structure-learner selection: algorithm name plus forwarded kwargs."""

    # Which structure-learning algorithm to invoke during `fit`.
    name: Literal["learn_spn", "prometheus"]
    # Extra keyword arguments forwarded to the learner (may be None).
    kwargs: dict[str, Any] | None
class SPFlowDensityEstimator(BaseEstimator, DensityMixin):
    """scikit-learn compatible density estimator for SPFlow models.

    Supports two workflows:

    - **Structure learning**: learn a model from data via `learn_spn` or `learn_prometheus`.
    - **Parameter fitting**: fit parameters of a provided SPFlow model via MLE.

    Args:
        model: Optional SPFlow model to fit and use for scoring/sampling.
        structure_learner: "learn_spn" or "prometheus". Only used when `model` is None.
        structure_learner_kwargs: Keyword arguments forwarded to the structure learner.
        fit_params: If True and `model` is provided, run MLE (`maximum_likelihood_estimation`) in `fit`.
        leaf: Leaf family used when learning structure and `model` is None. Currently supports "normal".
        leaf_out_channels: Output channels for the leaf module template (passed to `Normal`).
        min_instances_slice: Stopping criterion for structure learning (forwarded if not overridden).
        min_features_slice: Stopping criterion for structure learning (forwarded if not overridden).
        device: Torch device string (e.g., "cpu", "cuda"). If None, uses model device or the
            active PyTorch default device.
        dtype: Torch dtype string ("float32", "float64") for inputs.
        channel_agg: How to aggregate multiple output channels into a scalar log-likelihood.
        repetition_agg: How to aggregate multiple repetitions into a scalar log-likelihood.
    """

    # Shared error text for `fit_params` failures (raised from two places in `fit`).
    _MLE_ERROR = (
        "fit_params=True requires a model exposing maximum_likelihood_estimation "
        "(typically leaf modules). "
        "For general circuit models, use spflow.learn.expectation_maximization(...)."
    )

    def __init__(
        self,
        model: Module | None = None,
        *,
        structure_learner: Literal["learn_spn", "prometheus"] = "learn_spn",
        structure_learner_kwargs: dict[str, Any] | None = None,
        fit_params: bool = True,
        leaf: Literal["normal"] = "normal",
        leaf_out_channels: int = 1,
        min_instances_slice: int = 100,
        min_features_slice: int = 2,
        device: str | None = None,
        dtype: Literal["float32", "float64"] | None = None,
        channel_agg: Literal["logmeanexp", "logsumexp", "first"] = "logmeanexp",
        repetition_agg: Literal["logmeanexp", "logsumexp", "first"] = "logmeanexp",
    ) -> None:
        _require_sklearn()
        # sklearn convention: __init__ stores hyperparameters verbatim, no validation.
        self.model = model
        self.structure_learner = structure_learner
        self.structure_learner_kwargs = structure_learner_kwargs
        self.fit_params = fit_params
        self.leaf = leaf
        self.leaf_out_channels = leaf_out_channels
        self.min_instances_slice = min_instances_slice
        self.min_features_slice = min_features_slice
        self.device = device
        self.dtype = dtype
        self.channel_agg = channel_agg
        self.repetition_agg = repetition_agg

    def _device(self) -> torch.device:
        """Resolve the torch device: explicit setting > fitted/provided model > default."""
        if self.device is not None:
            return torch.device(self.device)
        if hasattr(self, "model_") and isinstance(self.model_, Module):
            return self.model_.device
        if self.model is not None:
            return self.model.device
        return _default_torch_device()

    def _to_tensor(self, x: Any) -> torch.Tensor:
        """Convert array-like input to a 2D tensor on the configured device/dtype."""
        arr = _as_2d_numpy(x)
        dtype = _torch_dtype_from_str(self.dtype)
        return torch.as_tensor(arr, dtype=dtype, device=self._device())

    def _leaf_template(self, n_features: int) -> Any:
        """Build the leaf module template used by structure learning."""
        if self.leaf != "normal":
            raise InvalidParameterError(f"Unknown leaf '{self.leaf}'.")
        return Normal(scope=Scope(list(range(n_features))), out_channels=self.leaf_out_channels).to(
            self._device()
        )

    def _structure_spec(self) -> _StructureLearnerSpec:
        """Validate and package the structure-learner hyperparameters."""
        if self.structure_learner not in ("learn_spn", "prometheus"):
            raise InvalidParameterError(
                "structure_learner must be 'learn_spn' or 'prometheus', " f"got '{self.structure_learner}'."
            )
        return _StructureLearnerSpec(name=self.structure_learner, kwargs=self.structure_learner_kwargs)

    def fit(self, X: Any, y: Any | None = None) -> "SPFlowDensityEstimator":
        """Fit a density model.

        Args:
            X: Array-like of shape (n_samples, n_features).
            y: Ignored. Present for scikit-learn compatibility.

        Returns:
            The fitted estimator (self), with `model_` and `n_features_in_` set.
        """
        del y
        x_tensor = self._to_tensor(X)
        self.n_features_in_ = int(x_tensor.shape[1])
        if self.model is None:
            # Structure-learning workflow: build a model from scratch.
            leaf_modules = self._leaf_template(self.n_features_in_)
            spec = self._structure_spec()
            learner_kwargs: dict[str, Any] = {
                "out_channels": 1,
                "min_instances_slice": self.min_instances_slice,
                "min_features_slice": self.min_features_slice,
            }
            if spec.kwargs:
                # User-provided kwargs take precedence over the defaults above.
                learner_kwargs.update(spec.kwargs)
            if spec.name == "learn_spn":
                self.model_ = learn_spn(x_tensor, leaf_modules=leaf_modules, **learner_kwargs)
            else:
                self.model_ = learn_prometheus(x_tensor, leaf_modules=leaf_modules, **learner_kwargs)
        else:
            # Parameter-fitting workflow: use the provided model as-is.
            if not isinstance(self.model, Module):
                raise InvalidTypeError(
                    f"model must be a spflow.modules.module.Module, got {type(self.model)}."
                )
            self.model_ = self.model
            if self.fit_params:
                mle = getattr(self.model_, "maximum_likelihood_estimation", None)
                if mle is None:
                    raise InvalidParameterError(self._MLE_ERROR)
                try:
                    mle(x_tensor)
                except UnsupportedOperationError as exc:
                    raise InvalidParameterError(self._MLE_ERROR) from exc
        return self

    def score_samples(self, X: Any) -> np.ndarray:
        """Compute per-sample log-likelihood under the fitted model.

        Returns:
            1D array of shape (n_samples,) with per-sample log-likelihoods.
        """
        check_is_fitted(self, attributes=["model_"])
        x_tensor = self._to_tensor(X)
        with torch.no_grad():
            ll = self.model_.log_likelihood(x_tensor)
            reduced = _reduce_log_likelihood(
                ll,
                channel_agg=self.channel_agg,
                repetition_agg=self.repetition_agg,
            )
        return reduced.detach().cpu().numpy()

    def sample(self, n_samples: int = 1, *, random_state: int | None = None) -> np.ndarray:
        """Generate samples from the fitted model.

        Args:
            n_samples: Number of samples to draw (Python or NumPy integer).
            random_state: Optional seed for deterministic sampling.

        Returns:
            Array of shape (n_samples, n_features).
        """
        check_is_fitted(self, attributes=["model_"])
        # Accept NumPy integers too (consistent with the `random_state` check below),
        # and reject bools explicitly since `bool` is a subclass of `int`.
        if (
            isinstance(n_samples, bool)
            or not isinstance(n_samples, (int, np.integer))
            or n_samples < 1
        ):
            raise InvalidParameterError(f"n_samples must be a positive integer, got {n_samples}.")
        if random_state is not None and not isinstance(random_state, (int, np.integer)):
            raise InvalidTypeError(f"random_state must be an int or None, got {type(random_state)}.")
        seed = int(random_state) if random_state is not None else None
        device = self._device()
        # fork_rng must be told which CUDA devices to snapshot; CPU needs none.
        cuda_devices: list[int] = []
        if device.type == "cuda":
            cuda_devices = [device.index or 0]
        with torch.random.fork_rng(devices=cuda_devices):
            if seed is not None:
                torch.manual_seed(seed)
            with torch.no_grad():
                samples = self.model_.sample(num_samples=int(n_samples))
        return samples.detach().cpu().numpy()
class SPFlowClassifier(BaseEstimator, ClassifierMixin):
    """scikit-learn compatible classifier wrapper for SPFlow classifiers.

    This wrapper delegates to a provided SPFlow model that implements
    `predict_proba(torch.Tensor) -> torch.Tensor`.
    """

    def __init__(
        self,
        model: Any,
        *,
        device: str | None = None,
        dtype: Literal["float32", "float64"] | None = None,
    ) -> None:
        _require_sklearn()
        # sklearn convention: store hyperparameters verbatim.
        self.model = model
        self.device = device
        self.dtype = dtype

    def _device(self) -> torch.device:
        """Resolve the torch device: explicit setting > model device > default."""
        if self.device is not None:
            return torch.device(self.device)
        if hasattr(self.model, "device"):
            return torch.device(self.model.device)
        return _default_torch_device()

    def _to_tensor(self, x: Any) -> torch.Tensor:
        """Convert array-like input to a 2D tensor on the configured device/dtype."""
        return torch.as_tensor(
            _as_2d_numpy(x),
            dtype=_torch_dtype_from_str(self.dtype),
            device=self._device(),
        )

    def fit(self, X: Any, y: Any) -> "SPFlowClassifier":
        """Store class labels for sklearn compatibility."""
        del X
        self.classes_ = np.unique(np.asarray(y))
        return self

    def predict_proba(self, X: Any) -> np.ndarray:
        """Return per-class probabilities from the wrapped model."""
        check_is_fitted(self, attributes=["classes_"])
        inputs = self._to_tensor(X)
        with torch.no_grad():
            probabilities = self.model.predict_proba(inputs)
        return probabilities.detach().cpu().numpy()

    def predict(self, X: Any) -> np.ndarray:
        """Predict class labels using argmax over predicted probabilities."""
        check_is_fitted(self, attributes=["classes_"])
        winners = self.predict_proba(X).argmax(axis=1)
        return np.asarray(self.classes_)[winners]