"""Contains the LearnSPN structure and parameter learner for SPFlow in the ``base`` backend."""
from collections.abc import Callable
from functools import partial
from itertools import combinations
from typing import Any
import numpy as np
import torch
from fast_pytorch_kmeans import KMeans
from networkx import connected_components as ccnp, from_numpy_array
from spflow.meta.data.scope import Scope
from spflow.modules.module import Module
from spflow.modules.leaves.leaf import LeafModule
from spflow.modules.ops.cat import Cat
from spflow.modules.products import Product
from spflow.modules.sums import Sum
from spflow.utils.rdc import rdc
def prune_sums(node):
"""Prune unnecessary sum nodes from a probabilistic circuit.
Recursively traverses the circuit and removes redundant sum nodes by flattening
nested sum-cat structures and merging weights. Reduces circuit complexity while
preserving the probability distribution.
Args:
node: Root node of the circuit to prune. Can be any module type,
but pruning only affects Sum nodes.
Returns:
None: Modifies the circuit in-place.
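    Example:
        A minimal sketch; the circuit is assumed to come from ``learn_spn`` and the
        variable names are illustrative only::

            circuit = learn_spn(data, leaf_modules=leaf)
            prune_sums(circuit)  # collapses Sum -> Cat -> Sum chains in place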
"""
if isinstance(node, Sum):
child = node.inputs
new_children = []
new_weights = []
if isinstance(child, Cat):
# prune if all children of the cat module are sums
all_sums = all(isinstance(c, Sum) for c in child.inputs)
if all_sums:
                for c in child.inputs:
new_children.append(c.inputs)
new_weights.append(c.weights)
if len(new_children) != 0:
# if we have new children, we need to update the weights
current_weights = node.weights
updated_weights = []
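            # collapse two stacked sum layers: each grand-child's new weight is its own
            # weight scaled by the outer sum's weight for that child, which leaves the
            # represented mixture distribution unchanged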
for i in range(len(new_weights)):
updated_weights.append(
new_weights[i] * current_weights[:, i, :].unsqueeze(1),
)
updated_weights = torch.concatenate(updated_weights, dim=1)
all_cat = all(isinstance(c, Cat) for c in new_children)
if all_cat:
# if cat(cat) -> cat
node.inputs = Cat([input_elem for c in new_children for input_elem in c.inputs], dim=2)
else:
node.inputs = Cat(new_children, dim=2)
node.weights_shape = updated_weights.shape
node.weights = updated_weights
            # prune the same node again in case the flattening created new nested sums
prune_sums(node)
else:
            # recurse into the inputs unless they are leaf modules
if not isinstance(node.inputs, LeafModule):
prune_sums(node.inputs)
elif isinstance(node, Product):
if isinstance(node.inputs, Cat):
prune_sums(node.inputs)
elif node.inputs is not None and not isinstance(node.inputs, LeafModule):
prune_sums(node.inputs)
elif isinstance(node, Cat):
for child in node.inputs:
if not isinstance(child, LeafModule):
prune_sums(child)
def adapt_product_inputs(inputs: list[Module], leaf_oc: int, sum_oc: int) -> list[Module]:
    """Wrap product inputs with fewer output channels than the reference in a ``Sum``.

    The reference channel count is ``max(leaf_oc, sum_oc)``; inputs that already expose
    at least that many channels are passed through unchanged.
    """
    ref_oc = max(leaf_oc, sum_oc)
    output_modules = []
    for m in inputs:
        if m.out_shape.channels < ref_oc:
            sum_module = Sum(inputs=m, out_channels=ref_oc)
            output_modules.append(sum_module)
        else:
            output_modules.append(m)
    return output_modules
def partition_by_rdc(
data: torch.Tensor,
threshold: float = 0.3,
preprocessing: Callable | None = None,
) -> torch.Tensor:
"""Performs partitioning using randomized dependence coefficients (RDCs).
Args:
data: Two-dimensional Tensor containing the input data.
Each row corresponds to a sample.
threshold: Floating point value specifying the threshold for independence testing
between two features. Defaults to 0.3.
preprocessing: Optional callable that is called with ``data`` and returns another
Tensor of the same shape. Defaults to None.
Returns:
One-dimensional Tensor with the same number of entries as the number of features
in ``data``. Each integer value indicates the partition the corresponding feature
is assigned to.
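    Example:
        Illustrative only; the resulting grouping depends on the sampled data and the
        chosen threshold::

            x = torch.randn(500, 1)
            y = torch.randn(500, 1)
            data = torch.cat([x, x + 0.1 * torch.randn(500, 1), y], dim=1)
            ids = partition_by_rdc(data)
            # e.g. tensor([1, 1, 2], dtype=torch.int32): the two dependent features
            # share a partition, the independent one gets its own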
"""
# perform optional pre-processing of data
if preprocessing is not None:
partitioning_data = preprocessing(data)
else:
partitioning_data = data
# necessary for the correct precision for the rdc computation
# Save original dtype and reset after computation to avoid contaminating global state
original_dtype = torch.get_default_dtype()
try:
torch.set_default_dtype(torch.float64)
rdcs = torch.eye(data.shape[1], device=partitioning_data.device)
for i, j in combinations(range(partitioning_data.shape[1]), 2):
r = rdc(partitioning_data[:, i], partitioning_data[:, j])
rdcs[j][i] = rdcs[i][j] = r
# create adjacency matrix of features from thresholded rdcs
rdcs[rdcs < threshold] = 0.0
adj_mat = rdcs
partition_ids = torch.zeros(data.shape[1], dtype=torch.int)
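        # each connected component of the thresholded dependency graph becomes one partition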
        for i, c in enumerate(ccnp(from_numpy_array(adj_mat.cpu().numpy()))):
partition_ids[list(c)] = i + 1
return partition_ids.to(data.device)
finally:
torch.set_default_dtype(original_dtype)
def cluster_by_kmeans(
data: torch.Tensor,
n_clusters: int = 2,
preprocessing: Callable | None = None,
) -> torch.Tensor:
"""Performs clustering using k-Means.
Args:
data: Two-dimensional Tensor containing the input data.
Each row corresponds to a sample.
n_clusters: Integer value specifying the number of clusters to be used.
Defaults to 2.
preprocessing: Optional callable that is called with ``data`` and returns another
Tensor of the same shape. Defaults to None.
Returns:
One-dimensional Tensor with the same number of entries as the number of samples
in ``data``. Each integer value indicates the cluster the corresponding sample
is assigned to.
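    Example:
        Illustrative only; which cluster receives which label depends on the random
        k-Means initialization::

            data = torch.cat([torch.randn(100, 2) - 5.0, torch.randn(100, 2) + 5.0])
            labels = cluster_by_kmeans(data, n_clusters=2)
            # labels is a length-200 tensor with values in {0, 1}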
"""
# perform optional pre-processing of data
if preprocessing is not None:
clustering_data = preprocessing(data)
else:
clustering_data = data
kmeans = KMeans(n_clusters=n_clusters, mode="euclidean", verbose=1)
data_labels = kmeans.fit_predict(clustering_data)
return data_labels
def learn_spn(
data: torch.Tensor,
leaf_modules: list[LeafModule] | LeafModule,
out_channels: int = 1,
min_features_slice: int = 2,
min_instances_slice: int = 100,
    scope: Scope | None = None,
clustering_method: str | Callable = "kmeans",
partitioning_method: str | Callable = "rdc",
clustering_args: dict[str, Any] | None = None,
partitioning_args: dict[str, Any] | None = None,
full_data: torch.Tensor | None = None,
) -> Module:
"""LearnSPN structure and parameter learner.
LearnSPN algorithm as described in (Gens & Domingos, 2013): "Learning the Structure of Sum-Product Networks".
Args:
data: Two-dimensional Tensor containing the input data.
Each row corresponds to a sample.
leaf_modules: List of leaf modules or single leaf module to use for learning.
out_channels: Number of output channels. Defaults to 1.
min_features_slice: Minimum number of features required to partition.
Defaults to 2.
min_instances_slice: Minimum number of instances required to cluster.
Defaults to 100.
scope: Scope for the SPN. If None, inferred from leaf_modules.
clustering_method: String or callable specifying the clustering method.
If 'kmeans', k-Means clustering is used. If a callable, it should accept
data and return cluster assignments.
partitioning_method: String or callable specifying the partitioning method.
If 'rdc', randomized dependence coefficients are used. If a callable, it
should accept data and return partition assignments.
clustering_args: Optional dictionary of keyword arguments for clustering method.
partitioning_args: Optional dictionary of keyword arguments for partitioning method.
full_data: Optional full dataset for parameter estimation.
Returns:
A Module representing the learned SPN.
Raises:
ValueError: If arguments are invalid or scopes are not disjoint.
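    Example:
        A minimal sketch; ``Normal`` stands in for whatever leaf module fits the data,
        so its name and constructor arguments are assumptions rather than part of this
        function::

            data = torch.randn(1000, 4)
            leaf = Normal(scope=Scope([0, 1, 2, 3]), out_channels=1)
            spn = learn_spn(data, leaf_modules=leaf, min_instances_slice=100)
            prune_sums(spn)  # optional structural simplification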
"""
if scope is None:
if isinstance(leaf_modules, list):
if len(leaf_modules) > 1:
if not Scope.all_pairwise_disjoint([module.scope for module in leaf_modules]):
raise ValueError("Leaf modules must have disjoint scopes.")
scope = leaf_modules[0].scope
for leaf in leaf_modules[1:]:
scope = scope.join(leaf.scope)
else:
scope = leaf_modules[0].scope
else:
scope = leaf_modules.scope
leaf_modules = [leaf_modules]
# Verify that all indices in scope are valid for the data
# if len(scope.query) > 0 and max(scope.query) >= data.shape[1]:
# raise ValueError(f"Scope indices {scope.query} exceed data features {data.shape[1]}.")
# available off-the-shelf clustering methods provided by SPFlow
if isinstance(clustering_method, str):
        # k-Means clustering
        if clustering_method == "kmeans":
            clustering_method = cluster_by_kmeans
        else:
            raise ValueError(f"Value '{clustering_method}' for clustering method is invalid.")
# available off-the-shelf partitioning methods provided by SPFlow
if isinstance(partitioning_method, str):
# Randomized Dependence Coefficients (RDCs)
if partitioning_method == "rdc":
partitioning_method = partition_by_rdc
else:
raise ValueError(f"Value '{partitioning_method}' for partitioning method is invalid.")
# for convenience, directly bind additional keyword arguments to the methods
if clustering_args is not None:
clustering_method = partial(clustering_method, **clustering_args)
if partitioning_args is not None:
partitioning_method = partial(partitioning_method, **partitioning_args)
if not isinstance(min_instances_slice, int) or min_instances_slice < 2:
raise ValueError(
f"Value for 'min_instances_slice' must be an integer greater than 1, but was: {min_instances_slice}."
)
if not isinstance(min_features_slice, int) or min_features_slice < 2:
raise ValueError(
f"Value for 'min_features_slice' must be an integer greater than 1, but was: {min_features_slice}."
)
def create_partitioned_mv_leaf(scope: Scope, data: torch.Tensor):
"""Create partitioned leaf nodes from scope and data.
Creates leaf distributions by matching scope with available leaf modules and
estimating parameters via maximum likelihood estimation.
Args:
scope: Variable scope defining which variables to create leaves for.
data: Training data for parameter estimation.
Returns:
Union[Product, LeafModule]: Product node for multiple variables,
or single leaf for univariate case.
"""
leaves = []
s = set(scope.query)
for leaf_module in leaf_modules:
leaf_scope = set(leaf_module.scope.query)
scope_inter = s.intersection(leaf_scope)
if len(scope_inter) > 0:
leaf_layer = leaf_module.__class__(
scope=Scope(sorted(scope_inter)), out_channels=leaf_module.out_shape.channels
)
# estimate leaves node parameters from data
leaf_layer.maximum_likelihood_estimation(data)
leaves.append(leaf_layer)
if len(scope.query) > 1:
if len(leaves) == 1:
return Product(inputs=leaves[0])
else:
return Product(leaves)
else:
return leaves[0]
    # feature set does not need to be split any further
if len(scope.query) < min_features_slice:
return create_partitioned_mv_leaf(scope, data)
else:
        # partition the features in scope; skipped for a single sample, where
        # dependencies cannot be estimated
        partitions = []
        if data.shape[0] != 1:
            partition_ids = partitioning_method(data[:, scope.query])
            # compute partitions of rvs from partition id labels (torch.unique returns sorted ids)
            for partition_id in torch.unique(partition_ids):
                partitions.append(torch.where(partition_ids == partition_id))
# multiple partition (i.e., data can be partitioned)
if len(partitions) > 1:
product_inputs = []
for partition in partitions:
sub_structure = learn_spn(
data=data,
leaf_modules=leaf_modules,
scope=Scope([scope.query[rv] for rv in partition[0]]),
out_channels=out_channels,
clustering_method=clustering_method,
partitioning_method=partitioning_method,
min_features_slice=min_features_slice,
min_instances_slice=min_instances_slice,
)
product_inputs.append(sub_structure)
leaf_oc = (
leaf_modules[0].out_shape.channels
if isinstance(leaf_modules, list)
else leaf_modules.out_shape.channels
)
adapted_product_inputs = adapt_product_inputs(product_inputs, leaf_oc, out_channels)
return Product(adapted_product_inputs)
else:
            # not enough instances to cluster: create a leaf layer instead (this threshold
            # can be raised to limit overfitting or to reduce network size)
if data.shape[0] < min_instances_slice:
return create_partitioned_mv_leaf(scope, data)
# cluster data
else:
labels_per_channel = []
                # create a clustering for each output channel
for i in range(out_channels):
labels = clustering_method(data)
labels_per_channel.append(labels)
# non-conditional clusters
if not scope.is_conditional():
sum_vectors = []
# create sum node for each channel
for labels in labels_per_channel:
inputs_per_channel = []
                        # recurse into each cluster to learn a sub-structure
for cluster_id in torch.unique(labels):
sub_structure = learn_spn(
data[labels == cluster_id, :],
leaf_modules=leaf_modules,
scope=scope,
out_channels=out_channels,
clustering_method=clustering_method,
partitioning_method=partitioning_method,
min_features_slice=min_features_slice,
min_instances_slice=min_instances_slice,
)
inputs_per_channel.append(sub_structure)
# compute weights
w = []
for cluster_id in torch.unique(labels):
probs = torch.sum(labels == cluster_id) / data.shape[0]
w.append(probs)
weights = torch.tensor(w).unsqueeze(0).unsqueeze(-1) # shape(1, num_clusters, 1)
if len(inputs_per_channel) == 1:
inputs = inputs_per_channel
else:
inputs = Cat(inputs_per_channel, dim=2)
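                            # a child may expose several output channels: spread its cluster
                            # weight uniformly across those channels so the weights still sum to one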
weights_stack = []
for idx, child in enumerate(inputs_per_channel):
out_c = child.out_shape.channels
weights_stack.append(weights[:, idx, :].repeat(out_c, 1) / out_c)
weights = (torch.cat(weights_stack)).unsqueeze(0).unsqueeze(-1)
sum_vectors.append(Sum(inputs=inputs, weights=weights))
if len(sum_vectors) == 1:
return sum_vectors[0]
else:
return Cat(sum_vectors, dim=2)
# conditional clusters
else:
raise NotImplementedError("Conditional clustering not yet implemented.")
"""
return CondSum(
children=[
learn_spn(
data[labels == cluster_id, :],
feature_ctx,
clustering_method=clustering_method,
partitioning_method=partitioning_method,
min_features_slice=min_features_slice,
min_instances_slice=min_instances_slice,
)
for cluster_id in torch.unique(labels)
],
)
"""