Source code for spflow.interfaces.classifier

"""Abstract base class for classification modules."""

from abc import ABC, abstractmethod

import torch


class Classifier(ABC):
    """Abstract base class for modules that support classification.

    Provides a standard interface for models that can predict class labels
    and class probabilities from input data. Subclasses must implement
    :meth:`predict_proba`; :meth:`predict` is derived from it.
    """

    @abstractmethod
    def predict_proba(self, data: torch.Tensor) -> torch.Tensor:
        """Predict class probabilities for input data.

        Args:
            data: Input data tensor.

        Returns:
            Class probability predictions. Each row corresponds to a data
            point, and each column corresponds to a class.
        """

    def predict(self, data: torch.Tensor) -> torch.Tensor:
        """Predict class labels for input data.

        The label of each data point is the index of its most probable
        class, i.e. the arg-max over the class (column) axis of the tensor
        returned by :meth:`predict_proba`.

        Args:
            data: Input data tensor.

        Returns:
            Predicted class labels as an integer (``torch.int64``) tensor
            with one entry per row of the probability matrix. When several
            classes share the maximal probability, the index of the first
            one is returned (``torch.argmax`` semantics).
        """
        probabilities = self.predict_proba(data)
        # Classes lie along dim 1 per the predict_proba contract
        # (rows are data points, columns are classes).
        return torch.argmax(probabilities, dim=1)