Source code for spflow.learn.gradient_descent

import logging
from collections.abc import Callable

import torch
import torch.nn as nn
from torch import Tensor

from spflow.exceptions import InvalidTypeError
from spflow.interfaces.classifier import Classifier
from spflow.modules.module import Module

logger = logging.getLogger(__name__)


class TrainingMetrics:
    """Track training and validation metrics during model training.

    Attributes:
        train_losses: List of training batch losses.
        val_losses: List of validation batch losses.
        train_correct: Number of correctly predicted training samples.
        train_total: Total number of training samples processed.
        val_correct: Number of correctly predicted validation samples.
        val_total: Total number of validation samples processed.
        training_steps: Total number of training batches processed.
        validation_steps: Total number of validation batches processed.
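
    Example:
        A minimal per-epoch sketch (tensor values are illustrative):

            metrics = TrainingMetrics()
            metrics.update_train_batch(
                loss=torch.tensor(0.7),
                predicted=torch.tensor([1, 0, 1]),
                targets=torch.tensor([1, 1, 1]),
            )
            metrics.get_train_accuracy()  # 66.67 (2 of 3 predictions correct)
            metrics.reset_epoch_metrics()  # clear counters before the next epoch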
    """

    def __init__(self) -> None:
        """Initialize a new TrainingMetrics instance.

        All metrics are initialized to zero or empty lists.
        """
        self.train_losses: list[Tensor] = []
        self.val_losses: list[Tensor] = []
        self.train_correct = 0
        self.train_total = 0
        self.val_correct = 0
        self.val_total = 0
        self.training_steps = 0
        self.validation_steps = 0

    def update_train_batch(
        self, loss: Tensor, predicted: Tensor | None = None, targets: Tensor | None = None
    ) -> None:
        """Update metrics after processing a training batch.

        Args:
            loss: The computed loss for the batch.
            predicted: Predicted class labels (optional, for classification).
            targets: Ground truth target labels (optional, for classification).
        """
        self.train_losses.append(loss)
        self.training_steps += 1
        if predicted is not None and targets is not None:
            self.train_total += targets.size(0)
            self.train_correct += (predicted == targets).sum().item()

    def update_val_batch(
        self, loss: Tensor, predicted: Tensor | None = None, targets: Tensor | None = None
    ) -> None:
        """Update metrics after processing a validation batch.

        Args:
            loss: The computed loss for the batch.
            predicted: Predicted class labels (optional, for classification).
            targets: Ground truth target labels (optional, for classification).
        """
        self.val_losses.append(loss)
        self.validation_steps += 1
        if predicted is not None and targets is not None:
            self.val_total += targets.size(0)
            self.val_correct += (predicted == targets).sum().item()

    def get_train_accuracy(self) -> float:
        """Calculate training accuracy percentage.

        Returns:
            float: Training accuracy as a percentage (0-100). Returns 0.0 if
            no training samples have been processed.
        """
        return 100 * self.train_correct / self.train_total if self.train_total > 0 else 0.0

    def get_val_accuracy(self) -> float:
        """Calculate validation accuracy percentage.

        Returns:
            float: Validation accuracy as a percentage (0-100). Returns 0.0 if
            no validation samples have been processed.
        """
        return 100 * self.val_correct / self.val_total if self.val_total > 0 else 0.0

    def reset_epoch_metrics(self) -> None:
        """Reset all epoch-specific metrics."""
        self.train_losses.clear()
        self.val_losses.clear()
        self.train_correct = 0
        self.train_total = 0
        self.val_correct = 0
        self.val_total = 0


def negative_log_likelihood_loss(model: Module, data: Tensor) -> torch.Tensor:
    """Compute negative log-likelihood loss.

    Args:
        model: Model to compute log-likelihood for.
        data: Input data tensor.

    Returns:
        torch.Tensor: Scalar negative log-likelihood loss tensor.
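
    Example:
        A minimal sketch (assumes model is a Module exposing log_likelihood
        and data is a batch of feature rows):

            loss = negative_log_likelihood_loss(model, data)
            loss.backward()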
    """
    return -model.log_likelihood(data).mean()


def classification_loss(ll: Tensor, target: Tensor) -> torch.Tensor:
    """Compute negative log-likelihood loss for classification tasks.

    Note:
        SPN models output log probabilities directly from their log_likelihood method,
        not raw logits like neural networks. Therefore, NLLLoss is the correct choice
        instead of CrossEntropyLoss. CrossEntropyLoss would apply log-softmax twice
        (once implicitly, once on already log-transformed probabilities), leading to
        incorrect results.

    Args:
        ll: Log-likelihood tensor containing per-class log-probabilities.
        target: Target class labels as a long tensor.

    Returns:
        torch.Tensor: Scalar negative log-likelihood loss tensor.
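
    Example:
        A worked sketch with per-class log-probabilities for two samples
        (values are illustrative):

            ll = torch.log(torch.tensor([[0.8, 0.2], [0.3, 0.7]]))
            target = torch.tensor([0, 1])
            classification_loss(ll, target)  # ~0.29, the mean of -log(0.8) and -log(0.7)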
    """
    return nn.NLLLoss()(ll.squeeze(-1), target)


def _extract_batch_data(
    batch: tuple[Tensor, ...] | Tensor, is_classification: bool
) -> tuple[Tensor, Tensor | None]:
    """Extract data and targets from batch with proper error handling.

    Args:
        batch: Input batch from dataloader.
        is_classification: Whether this is a classification task.

    Returns:
        Tuple of (data, targets) where targets may be None for non-classification.

    Raises:
        ValueError: If batch format is invalid for the task type.
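
    Example:
        A minimal sketch of the accepted batch shapes (tensors are illustrative):

            x = torch.randn(4, 3)
            y = torch.zeros(4, dtype=torch.long)
            _extract_batch_data((x, y), is_classification=True)   # -> (x, y)
            _extract_batch_data((x, y), is_classification=False)  # -> (x, None)
            _extract_batch_data(x, is_classification=False)       # -> (x, None)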
    """
    if is_classification:
        if not isinstance(batch, (tuple, list)) or len(batch) != 2:
            raise ValueError("Classification batches must be (data, targets) tuples")
        return batch[0], batch[1]

    # Handle non-classification batch formats: (data,), (data, ignored), or a bare tensor
    if isinstance(batch, (tuple, list)):
        if len(batch) in (1, 2):
            return batch[0], None  # Any second element (e.g. unused targets) is ignored
        raise ValueError("Non-classification batches should have 1 or 2 elements")
    return batch, None


def _process_training_batch(
    model: Module,
    batch: tuple[Tensor, ...] | Tensor,
    optimizer: torch.optim.Optimizer,
    loss_fn: Callable,
    metrics: TrainingMetrics,
    is_classification: bool,
    callback_batch: Callable[[Tensor, int], None] | None,
    nll_weight: float = 1.0,
) -> Tensor:
    """Process a single training batch and return the loss.

    Args:
        model: The model being trained.
        batch: Input batch from dataloader.
        optimizer: Optimizer for parameter updates.
        loss_fn: Loss function to compute.
        metrics: TrainingMetrics instance for tracking.
        is_classification: Whether this is a classification task.
        callback_batch: Optional callback function after each batch.
        nll_weight: Weight for the density estimation (NLL) term in classification tasks.

    Returns:
        The computed loss tensor.
    """
    # Clear gradients from previous step
    optimizer.zero_grad()
    data, targets = _extract_batch_data(batch, is_classification)

    # Compute loss based on task type (classification vs density estimation)
    if is_classification:
        # predict_proba output feeds classification_loss, which expects per-class
        # log-probabilities; the weighted NLL term adds a generative objective.
        log_probs = model.predict_proba(data)
        loss = loss_fn(log_probs, targets) + nll_weight * negative_log_likelihood_loss(model, data)
        predicted = torch.argmax(log_probs, dim=-1).squeeze()
        metrics.update_train_batch(loss, predicted, targets)
    else:
        loss = loss_fn(model, data)
        metrics.update_train_batch(loss)

    # Backpropagate and update weights
    loss.backward()
    optimizer.step()

    if callback_batch is not None:
        callback_batch(loss, metrics.training_steps)

    return loss


def _run_validation_epoch(
    model: Module,
    validation_dataloader: torch.utils.data.DataLoader,
    loss_fn: Callable,
    metrics: TrainingMetrics,
    is_classification: bool,
    callback_batch: Callable[[Tensor, int], None] | None,
    nll_weight: float = 1.0,
) -> Tensor:
    """Run validation epoch and return final validation loss.

    Args:
        model: The model being validated.
        validation_dataloader: DataLoader for validation data.
        loss_fn: Loss function to compute.
        metrics: TrainingMetrics instance for tracking.
        is_classification: Whether this is a classification task.
        callback_batch: Optional callback function after each batch.
        nll_weight: Weight for the density estimation (NLL) term in classification tasks.

    Returns:
        The final validation loss tensor from the last processed batch.
    """
    # Set model to evaluation mode
    model.eval()
    val_loss: Tensor

    # Validate without computing gradients
    with torch.no_grad():
        for batch in validation_dataloader:
            data, targets = _extract_batch_data(batch, is_classification)

            if is_classification:
                log_likelihood = model.log_likelihood(data)
                val_loss = loss_fn(log_likelihood, targets) + nll_weight * negative_log_likelihood_loss(
                    model, data
                )
                predicted = torch.argmax(log_likelihood, dim=-1).squeeze()
                metrics.update_val_batch(val_loss, predicted, targets)
            else:
                val_loss = loss_fn(model, data)
                metrics.update_val_batch(val_loss)

            if callback_batch is not None:
                callback_batch(val_loss, metrics.validation_steps)

    # Return to training mode
    model.train()
    return val_loss


def train_gradient_descent(
    model: Module,
    dataloader: torch.utils.data.DataLoader,
    epochs: int = -1,
    verbose: bool = False,
    is_classification: bool = False,
    optimizer: torch.optim.Optimizer | None = None,
    scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
    lr: float = 1e-3,
    loss_fn: Callable[[Module, Tensor], Tensor] | None = None,
    validation_dataloader: torch.utils.data.DataLoader | None = None,
    callback_batch: Callable[[Tensor, int], None] | None = None,
    callback_epoch: Callable[[list[Tensor], int], None] | None = None,
    nll_weight: float = 1.0,
):
    """Train model using gradient descent.

    Args:
        model: Model to train, must inherit from Module.
        dataloader: Training data loader yielding batches.
        epochs: Number of training epochs. Must be positive.
        verbose: Whether to log training progress per epoch.
        is_classification: Whether this is a classification task.
        optimizer: Optimizer instance. Defaults to Adam if None.
        scheduler: Learning rate scheduler. Defaults to MultiStepLR if None.
        lr: Learning rate for the default Adam optimizer.
        loss_fn: Custom loss function. Defaults based on task type if None.
        validation_dataloader: Validation data loader for periodic evaluation.
        callback_batch: Function called after each batch with (loss, step).
        callback_epoch: Function called after each epoch with (losses, epoch).
        nll_weight: Weight for the density estimation (NLL) term when
            is_classification=True. Controls the balance between discriminative
            and generative loss. Default is 1.0.

    Raises:
        ValueError: If epochs is not a positive integer.
        InvalidTypeError: If is_classification is True and model is not a
            Classifier instance.
    """
    # Input validation
    if epochs <= 0:
        raise ValueError("epochs must be a positive integer")
    if is_classification and not isinstance(model, Classifier):
        raise InvalidTypeError("model must be a Classifier instance when is_classification=True")

    # Initialize components
    model.train()
    if optimizer is None:
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if scheduler is None:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=[int(epochs * 0.5), int(epochs * 0.75)], gamma=0.1
        )

    # Initialize loss function based on task type
    if loss_fn is None:
        loss_fn = classification_loss if is_classification else negative_log_likelihood_loss

    metrics = TrainingMetrics()

    # Training loop
    for epoch in range(epochs):
        metrics.reset_epoch_metrics()

        # Process training batches
        for batch in dataloader:
            loss = _process_training_batch(
                model, batch, optimizer, loss_fn, metrics, is_classification, callback_batch, nll_weight
            )

        scheduler.step()

        # Log training metrics
        if is_classification:
            logger.debug(f"Accuracy: {metrics.get_train_accuracy():.2f}%")

        # Run validation
        if validation_dataloader is not None and epoch % 10 == 0:
            val_loss = _run_validation_epoch(
                model, validation_dataloader, loss_fn, metrics, is_classification, callback_batch, nll_weight
            )
            logger.debug(f"Validation Loss: {val_loss.item()}")
            if is_classification:
                logger.debug(f"Validation Accuracy: {metrics.get_val_accuracy():.2f}%")

        # Epoch callback and logging
        if callback_epoch is not None:
            callback_epoch(metrics.train_losses, epoch)
        if verbose:
            logger.info(f"Epoch [{epoch}/{epochs}]: Loss: {loss.item()}")
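
# Example usage (a minimal sketch; MyClassifier, its constructor, and the
# synthetic tensors are illustrative assumptions, not part of this module):
#
#     features = torch.randn(256, 4)
#     labels = torch.randint(0, 2, (256,))
#     loader = torch.utils.data.DataLoader(
#         torch.utils.data.TensorDataset(features, labels), batch_size=32
#     )
#     model = MyClassifier()  # any Module that also implements Classifier
#     train_gradient_descent(
#         model, loader, epochs=20, is_classification=True, nll_weight=0.5, verbose=True
#     )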