"""Hidden Chow-Liu Trees (HCLT) learner.
This builds a full probabilistic circuit whose structure is derived from a
Chow-Liu tree over the observed variables, and whose hidden states are modeled
via the channel dimension (H hidden categories per observed variable).
"""
from __future__ import annotations
import torch
from torch import Tensor
from spflow.exceptions import InvalidParameterError, ShapeError
from spflow.zoo.hclt.chow_liu import learn_chow_liu_trees_binary, learn_chow_liu_trees_categorical
from spflow.meta import Scope
from spflow.modules.leaves import Bernoulli, Categorical
from spflow.modules.products import ElementwiseProduct
from spflow.modules.sums import Sum
from spflow.modules.module import Module
from spflow.zoo.hclt.topk_mst import Edge
def _edges_to_adjacency(num_nodes: int, edges: list[Edge]) -> list[list[int]]:
adj: list[list[int]] = [[] for _ in range(num_nodes)]
for a, b in edges:
if a < 0 or b < 0 or a >= num_nodes or b >= num_nodes:
raise InvalidParameterError("Edge indices out of range.")
if a == b:
raise InvalidParameterError("Self-edges are not allowed.")
adj[a].append(b)
adj[b].append(a)
return adj
def _build_hclt_from_tree(
    *,
    num_features: int,
    edges: list[Edge],
    num_hidden_cats: int,
    emission_factory,
    init: str,
    device: torch.device | None,
    dtype: torch.dtype | None,
) -> Module:
    """Assemble a single HCLT circuit from a spanning tree over the features.

    The tree is rooted at feature 0 and traversed depth-first. Each node
    contributes an emission leaf with ``num_hidden_cats`` channels; each edge
    contributes a ``Sum`` over the child's hidden channels; the local emission
    and all child messages are combined with an ``ElementwiseProduct``. A
    final single-output ``Sum`` marginalizes the root's hidden state.

    Args:
        num_features: Number of observed variables (tree nodes).
        edges: Spanning-tree edges as ``(u, v)`` pairs; there must be exactly
            ``num_features - 1`` of them and they must connect all nodes.
        num_hidden_cats: Hidden categories per observed variable (>= 1).
        emission_factory: Callable ``var -> Module`` producing the emission
            leaf for a variable.
        init: ``"uniform"`` supplies explicit uniform Sum weights; any other
            value defers to the ``Sum`` module's default initialization.
        device: Optional device for created weight tensors.
        dtype: Optional dtype for created weight tensors.

    Returns:
        The root ``Sum`` module of the circuit.

    Raises:
        InvalidParameterError: On invalid hyperparameters or if ``edges`` does
            not form a spanning tree.
    """
    if num_hidden_cats < 1:
        raise InvalidParameterError("num_hidden_cats must be >= 1.")
    if len(edges) != num_features - 1:
        raise InvalidParameterError("Tree must have exactly num_features-1 edges.")
    adj = _edges_to_adjacency(num_features, edges)
    root = 0
    # Having num_features-1 edges is necessary but NOT sufficient for a tree:
    # a cycle plus a disconnected node passes the count check, and the
    # parent-skipping recursion below would then recurse forever around the
    # cycle. Verify connectivity explicitly before building.
    visited = [False] * num_features
    visited[root] = True
    stack = [root]
    while stack:
        node = stack.pop()
        for nb in adj[node]:
            if not visited[nb]:
                visited[nb] = True
                stack.append(nb)
    if not all(visited):
        raise InvalidParameterError("Edges must form a spanning tree over all features.")

    def uniform_sum_weights(in_ch: int, out_ch: int) -> Tensor:
        # Weight layout (1, in_channels, out_channels, 1): each output channel
        # mixes its inputs uniformly.
        w = torch.full((1, in_ch, out_ch, 1), 1.0 / float(in_ch))
        if device is not None:
            w = w.to(device=device)
        if dtype is not None:
            w = w.to(dtype=dtype)
        return w

    def build_subtree(node: int, parent: int) -> Module:
        # NOTE(review): recursion depth equals tree height; a path-shaped tree
        # over very many features may approach Python's recursion limit.
        emission = emission_factory(node)
        child_msgs: list[Module] = []
        for ch in adj[node]:
            if ch == parent:
                continue
            child_sub = build_subtree(ch, node)
            # Edge Sum: mixes the child's H hidden channels into H channels
            # for this node (uniform weights when init == "uniform").
            weights = None if init != "uniform" else uniform_sum_weights(num_hidden_cats, num_hidden_cats)
            trans = Sum(inputs=child_sub) if weights is None else Sum(inputs=child_sub, weights=weights)
            child_msgs.append(trans)
        if not child_msgs:
            return emission
        return ElementwiseProduct(inputs=[emission, *child_msgs])

    subtree = build_subtree(root, -1)
    # Root prior: collapse the root's H hidden channels into a single output.
    prior_weights = None if init != "uniform" else uniform_sum_weights(num_hidden_cats, 1)
    root_sum = Sum(inputs=subtree) if prior_weights is None else Sum(inputs=subtree, weights=prior_weights)
    return root_sum
def learn_hclt_binary(
    data: Tensor,
    *,
    num_hidden_cats: int,
    num_trees: int = 1,
    dropout_prob: float = 0.0,
    weights: Tensor | None = None,
    pseudocount: float = 1.0,
    init: str = "uniform",
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
) -> Module:
    """Learn an HCLT circuit from binary data.

    Args:
        data: (N, F) tensor with values in {0,1} (or bool). Must be complete (no NaNs).
        num_hidden_cats: Hidden categories per observed variable.
        num_trees: If >1, builds a mixture of HCLTs over the top-k Chow-Liu trees.
        dropout_prob: Edge dropout probability for top-k enumeration.
        weights: Optional per-sample weights.
        pseudocount: MI pseudocount (ChowLiuTrees.jl semantics).
        init: "uniform" or "random" (random uses module defaults).
        device/dtype: Optional placement overrides for created modules.

    Returns:
        The single HCLT circuit when one tree is learned, otherwise a ``Sum``
        mixture over the learned top-k HCLTs.

    Raises:
        ShapeError: If ``data`` is not 2D.
        InvalidParameterError: On NaNs in ``data`` or invalid hyperparameters.
    """
    if data.dim() != 2:
        raise ShapeError(f"data must be 2D (N,F), got shape {tuple(data.shape)}.")
    if torch.isnan(data).any():
        raise InvalidParameterError("learn_hclt_binary requires complete data (no NaNs).")
    if init not in ("uniform", "random"):
        raise InvalidParameterError("init must be 'uniform' or 'random'.")
    if num_trees < 1:
        raise InvalidParameterError("num_trees must be >= 1.")
    num_features = int(data.shape[1])
    trees = learn_chow_liu_trees_binary(
        data,
        num_trees=num_trees,
        dropout_prob=dropout_prob,
        weights=weights,
        pseudocount=pseudocount,
    )

    def emission_factory(var: int) -> Module:
        # One Bernoulli leaf per observed variable; out_channels carries the
        # hidden-state dimension.
        leaf = Bernoulli(scope=Scope([var]), out_channels=num_hidden_cats)
        if device is not None:
            leaf = leaf.to(device=device)
        if dtype is not None:
            leaf = leaf.to(dtype=dtype)
        return leaf

    hclts = [
        _build_hclt_from_tree(
            num_features=num_features,
            edges=edges,
            num_hidden_cats=num_hidden_cats,
            emission_factory=emission_factory,
            init=init,
            device=device,
            dtype=dtype,
        )
        for edges in trees
    ]
    if len(hclts) == 1:
        return hclts[0]
    # Mixture over top-k HCLTs (learnable mixture weights).
    mix_weights = None
    if init == "uniform":
        mix_weights = torch.full((1, len(hclts), 1, 1), 1.0 / float(len(hclts)))
        if device is not None:
            mix_weights = mix_weights.to(device=device)
        if dtype is not None:
            mix_weights = mix_weights.to(dtype=dtype)
    return Sum(inputs=hclts) if mix_weights is None else Sum(inputs=hclts, weights=mix_weights)
def learn_hclt_categorical(
    data: Tensor,
    *,
    num_hidden_cats: int,
    num_cats: int | None = None,
    num_trees: int = 1,
    dropout_prob: float = 0.0,
    weights: Tensor | None = None,
    pseudocount: float = 1.0,
    init: str = "uniform",
    device: torch.device | None = None,
    dtype: torch.dtype | None = None,
    chunk_size_pairs: int = 4096,
) -> Module:
    """Learn an HCLT circuit from categorical data.

    The structure is learned via a Chow-Liu tree on the observed variables, and
    emissions are `Categorical(X_i | Z_i)` with `num_hidden_cats` latent states.

    Args:
        data: (N, F) tensor of category indices. Must be complete (no NaNs).
        num_hidden_cats: Hidden categories per observed variable.
        num_cats: Number of observed categories; if None it is inferred as
            ``data.max() + 1``.
        num_trees: If >1, builds a mixture of HCLTs over the top-k Chow-Liu trees.
        dropout_prob: Edge dropout probability for top-k enumeration.
        weights: Optional per-sample weights.
        pseudocount: MI pseudocount (ChowLiuTrees.jl semantics).
        init: "uniform" or "random" (random uses module defaults).
        device/dtype: Optional placement overrides for created modules.
        chunk_size_pairs: Chunk size forwarded to the pairwise-statistics pass
            of the Chow-Liu learner.

    Returns:
        The single HCLT circuit when one tree is learned, otherwise a ``Sum``
        mixture over the learned top-k HCLTs.

    Raises:
        ShapeError: If ``data`` is not 2D.
        InvalidParameterError: On NaNs in ``data`` or invalid hyperparameters.
    """
    if data.dim() != 2:
        raise ShapeError(f"data must be 2D (N,F), got shape {tuple(data.shape)}.")
    if torch.isnan(data).any():
        raise InvalidParameterError("learn_hclt_categorical requires complete data (no NaNs).")
    if init not in ("uniform", "random"):
        raise InvalidParameterError("init must be 'uniform' or 'random'.")
    if num_trees < 1:
        raise InvalidParameterError("num_trees must be >= 1.")
    if num_cats is None:
        # Infer the category count from the data; empty data yields 0 and is
        # rejected below.
        num_cats = int(data.max().item()) + 1 if data.numel() else 0
    if num_cats <= 0:
        raise InvalidParameterError("num_cats must be >= 1.")
    num_features = int(data.shape[1])
    trees = learn_chow_liu_trees_categorical(
        data,
        num_cats=num_cats,
        num_trees=num_trees,
        dropout_prob=dropout_prob,
        weights=weights,
        pseudocount=pseudocount,
        chunk_size_pairs=chunk_size_pairs,
    )

    def emission_factory(var: int) -> Module:
        # One K-way Categorical leaf per observed variable; out_channels
        # carries the hidden-state dimension.
        leaf = Categorical(scope=Scope([var]), out_channels=num_hidden_cats, K=num_cats)
        if device is not None:
            leaf = leaf.to(device=device)
        if dtype is not None:
            leaf = leaf.to(dtype=dtype)
        return leaf

    hclts = [
        _build_hclt_from_tree(
            num_features=num_features,
            edges=edges,
            num_hidden_cats=num_hidden_cats,
            emission_factory=emission_factory,
            init=init,
            device=device,
            dtype=dtype,
        )
        for edges in trees
    ]
    if len(hclts) == 1:
        return hclts[0]
    # Mixture over top-k HCLTs (learnable mixture weights).
    mix_weights = None
    if init == "uniform":
        mix_weights = torch.full((1, len(hclts), 1, 1), 1.0 / float(len(hclts)))
        if device is not None:
            mix_weights = mix_weights.to(device=device)
        if dtype is not None:
            mix_weights = mix_weights.to(dtype=dtype)
    return Sum(inputs=hclts) if mix_weights is None else Sum(inputs=hclts, weights=mix_weights)