Metrics Collection
This module provides a predefined collection of torchmetrics metrics for training and evaluating models. The metrics can be used for classification tasks within or independently of the ConfigILM framework. They are a convenience feature that standardizes the metrics used during training and evaluation, provides a common interface, and reduces boilerplate code. The collection is a thin wrapper around the torchmetrics library and can be used in the same way as any other torchmetrics metric.
- configilm.metrics.get_classification_metric_collection(task, average=None, num_classes=None, num_labels=None, exclude=None, prefix=None)
Get a collection of classification metrics. The metrics are chosen based on the task and the average parameter. By default, all metrics are included. The exclude parameter can be used to exclude specific metrics. For some combinations of task and average, some metrics are not available and will be excluded automatically.
Right now, only multilabel classification is implemented. The other tasks are placeholders for future implementations.
- The default metrics are:
Accuracy
AUROC
AveragePrecision
F1Score
F2Score
Precision
Recall
Specificity
- The following metrics are removed by default:
AUROC and AveragePrecision for task="multilabel" and average="sample"
- Parameters:
task (Literal['binary', 'multiclass', 'multilabel']) – The classification task. One of "binary", "multiclass", or "multilabel".
average (Optional[Literal['macro', 'micro', 'sample']]) – The averaging strategy to be used. One of "macro", "micro", "sample", or None.
num_classes (Optional[int]) – The number of classes in the dataset. Required for multiclass classification.
num_labels (Optional[int]) – The number of labels in the dataset. Required for multilabel classification.
exclude (Optional[list[str]]) – A list of metric names to exclude from the collection.
prefix (Optional[str]) – A prefix to add to the metric names.
- Returns:
A MetricCollection object containing the chosen metrics.
- Return type:
MetricCollection
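The selection rules described above (default metrics, automatic exclusions, the exclude parameter, and the prefix) can be illustrated with a small pure-Python sketch. This is not the library's implementation: `sketch_metric_names`, `DEFAULT_METRICS`, and `AUTO_EXCLUDED` are hypothetical names introduced here, and the sketch assumes that exclude matches metric names exactly and that the prefix is simply prepended to each name.

```python
# Default metric names, as listed in the documentation above.
DEFAULT_METRICS = [
    "Accuracy",
    "AUROC",
    "AveragePrecision",
    "F1Score",
    "F2Score",
    "Precision",
    "Recall",
    "Specificity",
]

# Documented automatic exclusions, keyed by (task, average).
AUTO_EXCLUDED = {
    ("multilabel", "sample"): {"AUROC", "AveragePrecision"},
}


def sketch_metric_names(task, average=None, exclude=None, prefix=None):
    """Hypothetical helper: names the documented rules would keep."""
    dropped = set(AUTO_EXCLUDED.get((task, average), set()))
    # Assumption: user exclusions match names exactly (case-sensitive).
    dropped.update(exclude or [])
    kept = [name for name in DEFAULT_METRICS if name not in dropped]
    # Assumption: the prefix is prepended verbatim to each metric name.
    return [f"{prefix or ''}{name}" for name in kept]


names = sketch_metric_names(
    "multilabel", average="sample", exclude=["Specificity"], prefix="val/"
)
print(names)
# ['val/Accuracy', 'val/F1Score', 'val/F2Score', 'val/Precision', 'val/Recall']
```

With the real function, a comparable call would presumably be get_classification_metric_collection("multilabel", average="sample", num_labels=19, exclude=["Specificity"], prefix="val/"), where num_labels=19 is only a placeholder value. The exact keys of the returned MetricCollection may differ from this sketch.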