Source code for spflow.learn.expectation_maximization
import logging
import torch
from torch import Tensor
from torch.utils.data import DataLoader
from spflow.modules.module import Module
from spflow.utils.cache import Cache
logger = logging.getLogger(__name__)


def expectation_maximization(
module: Module,
data: Tensor,
max_steps: int = -1,
verbose: bool = False,
) -> Tensor:
"""Performs expectation-maximization optimization on a given module.
Args:
module: Module to perform EM optimization on.
data: Two-dimensional tensor containing the input data. Each row corresponds to a sample.
max_steps: Maximum number of iterations. Defaults to -1, in which case optimization runs until convergence.
verbose: Whether to print the log-likelihood for each iteration step. Defaults to False.
Returns:
One-dimensional tensor containing the average log-likelihood for each iteration step.
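
    Example:
        A minimal usage sketch (illustrative only; ``model`` and ``data`` are assumed
        placeholders for an SPFlow module and a two-dimensional data tensor, not part
        of this module)::

            ll_history = expectation_maximization(model, data, max_steps=100, verbose=True)
            final_avg_ll = ll_history[-1]  # average log-likelihood of the last EM step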
"""
prev_avg_ll = torch.tensor(-float("inf"))
ll_history = []
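    # A negative max_steps means no explicit step limit; the convergence check below ends the loop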
if max_steps == -1:
max_steps = 2**64 - 1
for step in range(max_steps):
# Shared cache for this EM iteration
cache = Cache()
# compute log likelihoods and sum them together
module_lls = module.log_likelihood(data, cache=cache)
acc_ll = module_lls.sum()
avg_ll = acc_ll.detach().clone() / data.shape[0]
ll_history.append(avg_ll)
if verbose:
logger.info(f"Step {step}: Average log-likelihood: {avg_ll}")
# retain gradients for all module log-likelihoods
for lls in cache["log_likelihood"].values():
if torch.is_tensor(lls) and lls.requires_grad:
lls.retain_grad()
# compute gradients (if there are differentiable parameters to begin with)
if acc_ll.requires_grad:
acc_ll.backward(retain_graph=True)
# recursively perform expectation maximization
module.expectation_maximization(data, cache=cache)
        # end update loop once the average log-likelihood stops improving
if avg_ll <= prev_avg_ll:
if verbose:
logger.info(f"EM converged after {step} steps.")
break
prev_avg_ll = avg_ll
return torch.stack(ll_history)


def expectation_maximization_batched(
module: Module,
dataloader: DataLoader,
num_epochs: int = 1,
verbose: bool = False,
) -> Tensor:
"""Runs expectation-maximization over multiple epochs using mini-batches.
Args:
module: Module to perform EM optimization on.
dataloader: Dataloader yielding batches of input data tensors.
num_epochs: Number of epochs to iterate over the dataloader.
verbose: Whether to print the average log-likelihood per epoch.
Returns:
One-dimensional tensor containing the average log-likelihood for each epoch.
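
    Example:
        A minimal usage sketch (illustrative only; ``model`` and ``data`` are assumed
        placeholders for an SPFlow module and a two-dimensional data tensor)::

            from torch.utils.data import DataLoader, TensorDataset

            loader = DataLoader(TensorDataset(data), batch_size=256, shuffle=True)
            ll_history = expectation_maximization_batched(model, loader, num_epochs=5, verbose=True)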
"""
ll_history = []
for epoch in range(num_epochs):
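        # accumulate the total log-likelihood and number of samples over this epoch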
epoch_ll = None
num_samples = 0
for batch in dataloader:
batch_data = batch[0] if isinstance(batch, (list, tuple)) else batch
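            # shared cache for the EM update on this batch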
cache = Cache()
module_lls = module.log_likelihood(batch_data, cache=cache)
acc_ll = module_lls.sum()
if epoch_ll is None:
epoch_ll = torch.zeros((), device=module_lls.device, dtype=module_lls.dtype)
epoch_ll = epoch_ll + acc_ll.detach()
num_samples += batch_data.shape[0]
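            # retain gradients for all module log-likelihoods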
for lls in cache["log_likelihood"].values():
if torch.is_tensor(lls) and lls.requires_grad:
lls.retain_grad()
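            # compute gradients (if there are differentiable parameters to begin with)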
if acc_ll.requires_grad:
acc_ll.backward(retain_graph=True)
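            # recursively perform expectation maximization on this batch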
module.expectation_maximization(batch_data, cache=cache)
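        # average over all samples seen this epoch; NaN if the dataloader yielded no data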
if epoch_ll is None or num_samples == 0:
avg_ll = torch.tensor(float("nan"))
else:
avg_ll = epoch_ll / num_samples
ll_history.append(avg_ll)
if verbose:
logger.info(f"Epoch {epoch}: Average log-likelihood: {avg_ll}")
return torch.stack(ll_history)