import logging
from collections.abc import Callable
import torch
import torch.nn as nn
from torch import Tensor
from spflow.exceptions import InvalidTypeError
from spflow.interfaces.classifier import Classifier
from spflow.modules.module import Module
from spflow.utils.cache import Cache
logger = logging.getLogger(__name__)
class TrainingMetrics:
"""Track training and validation metrics during model training.
Attributes:
train_losses: List of training batch losses.
val_losses: List of validation batch losses.
train_correct: Number of correctly predicted training samples.
train_total: Total number of training samples processed.
val_correct: Number of correctly predicted validation samples.
val_total: Total number of validation samples processed.
training_steps: Total number of training batches processed.
validation_steps: Total number of validation batches processed.
"""
def __init__(self) -> None:
"""Initialize a new TrainingMetrics instance.
All metrics are initialized to zero or empty lists.
"""
self.train_losses: list[Tensor] = []
self.val_losses: list[Tensor] = []
self.train_correct = 0
self.train_total = 0
self.val_correct = 0
self.val_total = 0
self.training_steps = 0
self.validation_steps = 0
def update_train_batch(
self, loss: Tensor, predicted: Tensor | None = None, targets: Tensor | None = None
) -> None:
"""Update metrics after processing a training batch.
Args:
loss: The computed loss for the batch.
predicted: Predicted class labels (optional, for classification).
targets: Ground truth target labels (optional, for classification).
"""
self.train_losses.append(loss)
self.training_steps += 1
if predicted is not None and targets is not None:
self.train_total += targets.size(0)
self.train_correct += (predicted == targets).sum().item()
def update_val_batch(
self, loss: Tensor, predicted: Tensor | None = None, targets: Tensor | None = None
) -> None:
"""Update metrics after processing a validation batch.
Args:
loss: The computed loss for the batch.
predicted: Predicted class labels (optional, for classification).
targets: Ground truth target labels (optional, for classification).
"""
self.val_losses.append(loss)
self.validation_steps += 1
if predicted is not None and targets is not None:
self.val_total += targets.size(0)
self.val_correct += (predicted == targets).sum().item()
def get_train_accuracy(self) -> float:
"""Calculate training accuracy percentage.
Returns:
float: Training accuracy as a percentage (0-100). Returns 0.0 if
no training samples have been processed.
"""
return 100 * self.train_correct / self.train_total if self.train_total > 0 else 0.0
def get_val_accuracy(self) -> float:
"""Calculate validation accuracy percentage.
Returns:
float: Validation accuracy as a percentage (0-100). Returns 0.0 if
no validation samples have been processed.
"""
return 100 * self.val_correct / self.val_total if self.val_total > 0 else 0.0
def reset_epoch_metrics(self) -> None:
"""Reset all epoch-specific metrics."""
self.train_losses.clear()
self.val_losses.clear()
self.train_correct = 0
self.train_total = 0
self.val_correct = 0
self.val_total = 0
def negative_log_likelihood_loss(model: Module, data: Tensor) -> torch.Tensor:
    """Return the mean negative log-likelihood of ``data`` under ``model``.

    Args:
        model: Model providing a ``log_likelihood`` method.
        data: Input data tensor.

    Returns:
        torch.Tensor: Scalar negative log-likelihood loss.
    """
    mean_log_likelihood = model.log_likelihood(data).mean()
    return -mean_log_likelihood
def classification_loss(ll: Tensor, target: Tensor) -> torch.Tensor:
    """Compute negative log-likelihood loss for classification tasks.

    Note:
        SPN models emit log probabilities directly from ``log_likelihood``,
        not raw logits like neural networks. NLL loss is therefore the
        correct criterion; ``CrossEntropyLoss`` would apply log-softmax on
        top of values already in log space (effectively twice) and yield
        incorrect results.

    Args:
        ll: Log-likelihood tensor with per-class log-probabilities.
        target: Target class labels as long tensor.

    Returns:
        torch.Tensor: Scalar negative log-likelihood loss tensor.
    """
    class_log_probs = ll.squeeze(-1)
    return nn.functional.nll_loss(class_log_probs, target)
def _classifier_log_probabilities(model: Module, data: Tensor) -> Tensor:
    """Return classifier predictions in log-probability space.

    Prefers the model's ``log_posterior`` method when available; otherwise
    falls back to ``predict_proba`` and converts to log space, clamping away
    from zero so the logarithm stays finite.
    """
    log_posterior = getattr(model, "log_posterior", None)
    if callable(log_posterior):
        return log_posterior(data)
    probabilities = model.predict_proba(data)  # type: ignore[attr-defined]
    smallest_positive = torch.finfo(probabilities.dtype).tiny
    return probabilities.clamp_min(smallest_positive).log()
def _classifier_log_probabilities_cached(model: Module, data: Tensor, cache: Cache) -> Tensor:
    """Return classifier log-probabilities, sharing ``cache`` across traversals.

    Mirrors ``_classifier_log_probabilities`` but forwards the traversal cache
    to ``log_posterior`` so repeated model passes can reuse intermediate
    results. The ``predict_proba`` fallback does not take a cache.
    """
    log_posterior = getattr(model, "log_posterior", None)
    if callable(log_posterior):
        return log_posterior(data, cache=cache)
    probabilities = model.predict_proba(data)  # type: ignore[attr-defined]
    smallest_positive = torch.finfo(probabilities.dtype).tiny
    return probabilities.clamp_min(smallest_positive).log()
def _classification_forward_outputs(model: Module, data: Tensor) -> tuple[Tensor, Tensor]:
    """Compute class log-probabilities and generative log-likelihood for one batch.

    A single ``Cache`` instance is shared between the two passes so the model
    traversal done for the posterior can be reused by ``log_likelihood``.
    """
    shared_cache = Cache()
    class_log_probs = _classifier_log_probabilities_cached(model, data, cache=shared_cache)
    generative_ll = model.log_likelihood(data, cache=shared_cache)
    return class_log_probs, generative_ll
def _extract_batch_data(
batch: tuple[Tensor, ...] | Tensor, is_classification: bool
) -> tuple[Tensor, Tensor | None]:
"""Extract data and targets from batch with proper error handling.
Args:
batch: Input batch from dataloader.
is_classification: Whether this is a classification task.
Returns:
Tuple of (data, targets) where targets may be None for non-classification.
Raises:
ValueError: If batch format is invalid for the task type.
"""
if is_classification:
if not isinstance(batch, (tuple, list)) or len(batch) != 2:
raise ValueError("Classification batches must be (data, targets) tuples")
return batch[0], batch[1]
# Handle non-classification batch formats
if isinstance(batch, (tuple, list)):
if len(batch) == 1:
return batch[0], None
elif len(batch) == 2:
return batch[0], None # Ignore second element
else:
raise ValueError("Non-classification batches should have 1 or 2 elements")
else:
return batch, None
def _process_training_batch(
    model: Module,
    batch: tuple[Tensor, ...] | Tensor,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable,
    metrics: TrainingMetrics,
    is_classification: bool,
    callback_batch: Callable[[Tensor, int], None] | None,
    nll_weight: float = 1.0,
) -> Tensor:
    """Run one optimization step on a single batch and return its loss.

    Args:
        model: The model being trained.
        batch: Batch yielded by the dataloader.
        optimizer: Optimizer performing the parameter update.
        loss_fn: Loss function for the task.
        metrics: Metrics accumulator, updated in place.
        is_classification: Whether this is a classification task.
        callback_batch: Optional hook called with (loss, step) after the step.
        nll_weight: Weight of the generative (NLL) term for classification.

    Returns:
        The loss tensor computed for this batch.
    """
    # Discard gradients accumulated by the previous step.
    optimizer.zero_grad()
    data, targets = _extract_batch_data(batch, is_classification)
    if is_classification:
        # Hybrid objective: discriminative loss plus weighted generative NLL.
        log_probs, log_likelihood = _classification_forward_outputs(model, data)
        batch_loss = loss_fn(log_probs, targets) + nll_weight * (-log_likelihood.mean())
        predicted = torch.argmax(log_probs, dim=-1).squeeze()
        metrics.update_train_batch(batch_loss, predicted, targets)
    else:
        batch_loss = loss_fn(model, data)
        metrics.update_train_batch(batch_loss)
    # Backpropagate and apply the parameter update.
    batch_loss.backward()
    optimizer.step()
    if callback_batch is not None:
        callback_batch(batch_loss, metrics.training_steps)
    return batch_loss
def _run_validation_epoch(
    model: Module,
    validation_dataloader: torch.utils.data.DataLoader,
    loss_fn: Callable,
    metrics: TrainingMetrics,
    is_classification: bool,
    callback_batch: Callable[[Tensor, int], None] | None,
    nll_weight: float = 1.0,
) -> Tensor:
    """Run one validation pass and return the loss of the last batch.

    Args:
        model: The model being validated.
        validation_dataloader: DataLoader yielding validation batches.
        loss_fn: Loss function to compute.
        metrics: TrainingMetrics instance for tracking, updated in place.
        is_classification: Whether this is a classification task.
        callback_batch: Optional callback invoked after each batch.
        nll_weight: Weight for the density estimation (NLL) term in
            classification tasks.

    Returns:
        The validation loss tensor from the last processed batch.

    Raises:
        ValueError: If ``validation_dataloader`` yields no batches.
            (Previously this surfaced as an UnboundLocalError on return.)
    """
    # Switch to evaluation mode for the duration of the pass.
    model.eval()
    val_loss: Tensor | None = None
    # No gradients are needed for validation.
    with torch.no_grad():
        for batch in validation_dataloader:
            data, targets = _extract_batch_data(batch, is_classification)
            if is_classification:
                log_probs, log_likelihood = _classification_forward_outputs(model, data)
                val_loss = loss_fn(log_probs, targets) + nll_weight * (-log_likelihood.mean())
                predicted = torch.argmax(log_probs, dim=-1).squeeze()
                metrics.update_val_batch(val_loss, predicted, targets)
            else:
                val_loss = loss_fn(model, data)
                metrics.update_val_batch(val_loss)
            if callback_batch is not None:
                # NOTE(review): the callback receives the *training* step counter,
                # matching the original code (possibly to log against the global
                # step) — confirm whether metrics.validation_steps was intended.
                callback_batch(val_loss, metrics.training_steps)
    # Restore training mode before returning to the training loop.
    model.train()
    if val_loss is None:
        raise ValueError("validation_dataloader yielded no batches")
    return val_loss
def train_gradient_descent(
    model: Module,
    dataloader: torch.utils.data.DataLoader,
    epochs: int = -1,
    verbose: bool = False,
    is_classification: bool = False,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
    lr: float = 1e-3,
    loss_fn: Callable[[Module, Tensor], Tensor] | None = None,
    validation_dataloader: torch.utils.data.DataLoader | None = None,
    callback_batch: Callable[[Tensor, int], None] | None = None,
    callback_epoch: Callable[[list[Tensor], int], None] | None = None,
    nll_weight: float = 1.0,
):
    """Train model using gradient descent.

    Args:
        model: Model to train, must inherit from Module.
        dataloader: Training data loader yielding batches.
        epochs: Number of training epochs. Must be positive.
        verbose: Whether to log training progress per epoch.
        is_classification: Whether this is a classification task.
        optimizer: Optimizer instance. Defaults to Adam if None.
        scheduler: Learning rate scheduler. Defaults to MultiStepLR if None.
        lr: Learning rate for the default Adam optimizer.
        loss_fn: Custom loss function. Defaults based on task type if None.
        validation_dataloader: Validation data loader for periodic evaluation.
        callback_batch: Function called after each batch with (loss, step).
        callback_epoch: Function called after each epoch with (losses, epoch).
        nll_weight: Weight for the density estimation (NLL) term when
            is_classification=True. Controls the balance between discriminative
            and generative loss. Default is 1.0.

    Raises:
        ValueError: If epochs is not a positive integer.
        InvalidTypeError: If is_classification is True and model is not a
            Classifier instance.
    """
    # Input validation.
    if epochs <= 0:
        raise ValueError("epochs must be a positive integer")
    if is_classification and not isinstance(model, Classifier):
        raise InvalidTypeError("model must be a Classifier instance when is_classification=True")

    # Initialize components.
    model.train()
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if scheduler is None:
        # Decay the learning rate by 10x at 50% and 75% of the scheduled epochs.
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[int(epochs * 0.5), int(epochs * 0.75)], gamma=0.1
        )
    # Select the default loss based on task type.
    if loss_fn is None:
        loss_fn = classification_loss if is_classification else negative_log_likelihood_loss

    metrics = TrainingMetrics()
    # Training loop.
    for epoch in range(epochs):
        metrics.reset_epoch_metrics()
        loss: Tensor | None = None
        for batch in dataloader:
            loss = _process_training_batch(
                model, batch, optimizer, loss_fn, metrics, is_classification, callback_batch, nll_weight
            )
        # Advance the LR schedule once per epoch (milestones are epoch indices).
        scheduler.step()
        if is_classification:
            logger.debug(f"Accuracy: {metrics.get_train_accuracy():.2f}%")
        # Validate every 10th epoch (including the first).
        if validation_dataloader is not None and epoch % 10 == 0:
            val_loss = _run_validation_epoch(
                model, validation_dataloader, loss_fn, metrics, is_classification, callback_batch, nll_weight
            )
            logger.debug(f"Validation Loss: {val_loss.item()}")
            if is_classification:
                logger.debug(f"Validation Accuracy: {metrics.get_val_accuracy():.2f}%")
        # Epoch callback and logging.
        if callback_epoch is not None:
            callback_epoch(metrics.train_losses, epoch)
        # Guard against an empty dataloader, which would leave `loss` unset.
        if verbose and loss is not None:
            logger.info(f"Epoch [{epoch}/{epochs}]: Loss: {loss.item()}")