Metrics

Metrics are specified in the dataset configuration using shorthand syntax. The metrics attribute lists the metrics to compute during training, and the tracking_metric attribute names the metric used for early stopping.

Tip

To create custom metrics, refer to the custom metrics tutorial.

For example, a dataset for classification could specify the following metrics:

conf/dataset/ExampleDataset.yaml
id: ExampleDataset
_target_: example_dataset.ExampleDataset
...

tracking_metric: autrainer.metrics.Accuracy
metrics:
  - autrainer.metrics.Accuracy
  - autrainer.metrics.UAR
  - autrainer.metrics.F1
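
Each metric entry in the configuration is a dotted import path. As a rough illustration of how such a path can be turned into a class, the sketch below resolves one with importlib. The resolve_dotted_path helper is hypothetical and not part of autrainer, and a standard-library class stands in for a path such as autrainer.metrics.Accuracy so that the snippet runs without autrainer installed. How autrainer itself instantiates the shorthand is not shown here.

```python
import importlib


def resolve_dotted_path(path: str):
    """Hypothetical helper: import 'package.module.Name' and return Name."""
    module_name, _, attribute = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {path!r}")
    module = importlib.import_module(module_name)
    return getattr(module, attribute)


# Stand-in for a configuration entry such as "autrainer.metrics.Accuracy".
resolved = resolve_dotted_path("collections.OrderedDict")
print(resolved.__name__)
```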

Abstract Metric

class autrainer.metrics.AbstractMetric(name, fn, fallback, **fn_kwargs)

Abstract class for metrics.

Parameters:
  • name (str) – The name of the metric.

  • fn (Callable) – The function to compute the metric.

  • fallback (float) – The fallback value if the metric is NaN.

  • **fn_kwargs (dict) – Additional keyword arguments to pass to the function.

__call__(*args, **kwargs)

Compute the metric.

Parameters:
  • *args – Positional arguments for the function.

  • **kwargs – Keyword arguments for the function.

Return type:

float

Returns:

The score.
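
The constructor parameters and the call behaviour described above can be sketched with a small standalone class. This is an illustration under stated assumptions, not the autrainer implementation: SketchMetric and the toy accuracy function are hypothetical, and the way stored and call-time keyword arguments are merged is a guess.

```python
import math
from typing import Callable


class SketchMetric:
    """Hypothetical stand-in that mirrors the documented constructor."""

    def __init__(self, name: str, fn: Callable[..., float], fallback: float, **fn_kwargs):
        self.name = name
        self.fn = fn
        self.fallback = fallback
        self.fn_kwargs = fn_kwargs

    def __call__(self, *args, **kwargs) -> float:
        # Assumption: stored keyword arguments are merged with call-time ones.
        score = float(self.fn(*args, **{**self.fn_kwargs, **kwargs}))
        # An undefined (NaN) score is replaced by the fallback value.
        return self.fallback if math.isnan(score) else score


def accuracy(truth, prediction):
    """Toy accuracy that is NaN for empty input."""
    if len(truth) == 0:
        return float("nan")
    return sum(t == p for t, p in zip(truth, prediction)) / len(truth)


metric = SketchMetric("accuracy", accuracy, fallback=-1e32)
score = metric([0, 1, 1, 0], [0, 1, 0, 0])  # 3 of 4 correct: 0.75
empty = metric([], [])  # NaN is replaced by the fallback: -1e32
```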

abstract property starting_metric: float

The starting metric value.

Returns:

The starting metric value.

abstract property suffix: str

The suffix of the metric.

Returns:

The suffix of the metric.

abstract static get_best(a)

Get the best metric value from a series of scores.

Parameters:

a (Union[Series, ndarray]) – Pandas series or numpy array of scores.

Return type:

float

Returns:

Best metric value.

abstract static get_best_pos(a)

Get the position of the best metric value from a series of scores.

Parameters:

a (Union[Series, ndarray]) – Pandas series or numpy array of scores.

Return type:

int

Returns:

Position of the best metric value.

abstract static compare(a, b)

Compare two scores and return True if the first score is better.

Parameters:
  • a (Union[int, float]) – First score.

  • b (Union[int, float]) – Second score.

Return type:

bool

Returns:

True if the first score is better.
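
The three static helpers only take on meaning once a direction is fixed. The sketch below shows one plausible reading for both directions. The class names are hypothetical, plain Python lists stand in for the pandas series or numpy array named in the signatures, and the strict comparison in compare is an assumption.

```python
class HigherIsBetter:
    """Hypothetical helpers for a metric where larger scores are better."""

    @staticmethod
    def get_best(a) -> float:
        return float(max(a))

    @staticmethod
    def get_best_pos(a) -> int:
        # Position of the first occurrence of the maximum.
        return max(range(len(a)), key=lambda i: a[i])

    @staticmethod
    def compare(a, b) -> bool:
        return a > b


class LowerIsBetter:
    """Hypothetical helpers for a metric where smaller scores are better."""

    @staticmethod
    def get_best(a) -> float:
        return float(min(a))

    @staticmethod
    def get_best_pos(a) -> int:
        # Position of the first occurrence of the minimum.
        return min(range(len(a)), key=lambda i: a[i])

    @staticmethod
    def compare(a, b) -> bool:
        return a < b


scores = [0.2, 0.9, 0.5]
best_high = (HigherIsBetter.get_best(scores), HigherIsBetter.get_best_pos(scores))  # (0.9, 1)
best_low = (LowerIsBetter.get_best(scores), LowerIsBetter.get_best_pos(scores))  # (0.2, 0)
```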

class autrainer.metrics.BaseAscendingMetric(name, fn, fallback=None, **fn_kwargs)

Base for ascending metrics with higher values being better.

Parameters:
  • name (str) – The name of the metric.

  • fn (Callable) – The function to compute the metric.

  • fallback (Optional[float]) – The fallback value if the metric is NaN. If None, the fallback value is set to -1e32. Defaults to None.

  • **fn_kwargs (dict) – Additional keyword arguments to pass to the function.

property starting_metric: float

Ascending metric starting value.

Returns:

-1e32

property suffix: str

Ascending metric suffix.

Returns:

“max”

class autrainer.metrics.BaseDescendingMetric(name, fn, fallback=None, **fn_kwargs)

Base for descending metrics with lower values being better.

Parameters:
  • name (str) – The name of the metric.

  • fn (Callable) – The function to compute the metric.

  • fallback (Optional[float]) – The fallback value if the metric is NaN. If None, the fallback value is set to 1e32. Defaults to None.

  • **fn_kwargs (dict) – Additional keyword arguments to pass to the function.

property starting_metric: float

Descending metric starting value.

Returns:

1e32

property suffix: str

Descending metric suffix.

Returns:

“min”
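
The starting values of -1e32 and 1e32 are chosen so that any realistic first score counts as an improvement. A minimal sketch of how a training loop might use them to track the best epoch follows. The track_best function is hypothetical and only mirrors the documented starting values and directions; it is not how autrainer tracks scores internally.

```python
def track_best(scores, higher_is_better: bool):
    """Hypothetical tracker: return (best score, epoch index of the best score)."""
    # Documented starting values: -1e32 for ascending, 1e32 for descending metrics.
    best = -1e32 if higher_is_better else 1e32
    best_epoch = None
    for epoch, score in enumerate(scores):
        improved = score > best if higher_is_better else score < best
        if improved:
            best, best_epoch = score, epoch
    return best, best_epoch


accuracy_best = track_best([0.61, 0.74, 0.70], higher_is_better=True)  # (0.74, 1)
error_best = track_best([0.90, 0.42, 0.55], higher_is_better=False)  # (0.42, 1)
```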

Classification Metrics

class autrainer.metrics.Accuracy

Accuracy metric using audmetric.accuracy.

class autrainer.metrics.UAR

Unweighted average recall metric using audmetric.unweighted_average_recall.

class autrainer.metrics.F1

F1 metric using audmetric.unweighted_average_fscore.
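
To make the difference between the three classification metrics concrete, the sketch below computes them by hand in plain Python on an imbalanced toy example. It follows the textbook definitions (accuracy, mean of per-class recall, mean of per-class F1); the function names are hypothetical, and audmetric may treat edge cases such as classes missing from the ground truth differently.

```python
def accuracy(truth, prediction):
    return sum(t == p for t, p in zip(truth, prediction)) / len(truth)


def unweighted_average_recall(truth, prediction):
    # Recall per class, then an unweighted mean over the classes in truth.
    recalls = []
    for label in sorted(set(truth)):
        support = sum(t == label for t in truth)
        hits = sum(t == label and p == label for t, p in zip(truth, prediction))
        recalls.append(hits / support)
    return sum(recalls) / len(recalls)


def unweighted_average_f1(truth, prediction):
    # F1 per class, then an unweighted (macro) mean over all observed classes.
    scores = []
    for label in sorted(set(truth) | set(prediction)):
        tp = sum(t == label and p == label for t, p in zip(truth, prediction))
        fp = sum(t != label and p == label for t, p in zip(truth, prediction))
        fn = sum(t == label and p != label for t, p in zip(truth, prediction))
        denominator = 2 * tp + fp + fn
        scores.append(2 * tp / denominator if denominator else 0.0)
    return sum(scores) / len(scores)


truth = [0, 0, 0, 0, 1, 1]
prediction = [0, 0, 0, 0, 1, 0]

acc = accuracy(truth, prediction)  # 5 of 6 correct
uar = unweighted_average_recall(truth, prediction)  # (1.0 + 0.5) / 2 = 0.75
f1 = unweighted_average_f1(truth, prediction)  # (8/9 + 2/3) / 2 = 7/9
```

Accuracy rewards the majority class here, while the unweighted averages give the minority class equal weight.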

Regression Metrics

class autrainer.metrics.CCC

Concordance correlation coefficient metric using audmetric.concordance_cc.

class autrainer.metrics.MAE

Mean absolute error metric using audmetric.mean_absolute_error.

class autrainer.metrics.MSE

Mean squared error metric using audmetric.mean_squared_error.

class autrainer.metrics.PCC

Pearson correlation coefficient metric using audmetric.pearson_cc.
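
The four regression metrics can be written out in a few lines of plain Python. The sketch below uses the standard definitions with population (biased) variance and covariance; the function names are hypothetical, and audmetric's handling of degenerate input such as constant sequences is not reproduced.

```python
import math


def _mean(x):
    return sum(x) / len(x)


def _cov(x, y):
    # Population covariance; _cov(x, x) is the population variance.
    mean_x, mean_y = _mean(x), _mean(y)
    return _mean([(a - mean_x) * (b - mean_y) for a, b in zip(x, y)])


def mae(truth, prediction):
    return _mean([abs(t - p) for t, p in zip(truth, prediction)])


def mse(truth, prediction):
    return _mean([(t - p) ** 2 for t, p in zip(truth, prediction)])


def pcc(truth, prediction):
    # Constant input makes the denominator zero and raises ZeroDivisionError.
    return _cov(truth, prediction) / math.sqrt(
        _cov(truth, truth) * _cov(prediction, prediction)
    )


def ccc(truth, prediction):
    mean_gap = _mean(truth) - _mean(prediction)
    return 2 * _cov(truth, prediction) / (
        _cov(truth, truth) + _cov(prediction, prediction) + mean_gap**2
    )


truth = [1.0, 2.0, 3.0, 4.0]
prediction = [2.0, 3.0, 4.0, 5.0]  # truth shifted by a constant offset of 1

results = {
    "mae": mae(truth, prediction),  # 1.0
    "mse": mse(truth, prediction),  # 1.0
    "pcc": pcc(truth, prediction),  # 1.0, correlation ignores the offset
    "ccc": ccc(truth, prediction),  # 2.5 / 3.5 = 5/7, the offset is penalised
}
```

MAE and MSE are descending metrics (lower is better), whereas PCC and CCC are ascending.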

Multi-label Classification Metrics

class autrainer.metrics.MLAccuracy

Accuracy metric using sklearn.metrics.accuracy_score.

class autrainer.metrics.MLF1Macro

F1 macro metric using sklearn.metrics.f1_score.

class autrainer.metrics.MLF1Micro

F1 micro metric using sklearn.metrics.f1_score.

class autrainer.metrics.MLF1Weighted

F1 weighted metric using sklearn.metrics.f1_score.
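
The multi-label metrics differ only in how per-label results are aggregated. The plain-Python sketch below illustrates this on binary indicator rows: accuracy in the multi-label setting is subset accuracy (a sample counts only if every label matches), and the three F1 variants use macro, micro, and support-weighted averaging. The function names are hypothetical, and an undefined F1 is set to 0.0 as an assumption.

```python
def subset_accuracy(truth, prediction):
    # A row counts as correct only if all of its labels match.
    return sum(t == p for t, p in zip(truth, prediction)) / len(truth)


def f1_from_counts(tp, fp, fn):
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


def multilabel_f1(truth, prediction):
    counts = []
    for j in range(len(truth[0])):
        tp = sum(t[j] == 1 and p[j] == 1 for t, p in zip(truth, prediction))
        fp = sum(t[j] == 0 and p[j] == 1 for t, p in zip(truth, prediction))
        fn = sum(t[j] == 1 and p[j] == 0 for t, p in zip(truth, prediction))
        counts.append((tp, fp, fn))
    per_label = [f1_from_counts(*c) for c in counts]
    support = [tp + fn for tp, _, fn in counts]
    return {
        # Macro: unweighted mean of the per-label F1 scores.
        "macro": sum(per_label) / len(per_label),
        # Micro: F1 of the counts pooled over all labels.
        "micro": f1_from_counts(*(sum(column) for column in zip(*counts))),
        # Weighted: per-label F1 averaged by the number of true instances.
        "weighted": sum(s * f for s, f in zip(support, per_label)) / sum(support),
    }


truth = [[1, 0, 1], [0, 1, 0], [1, 1, 0], [1, 0, 0]]
prediction = [[1, 0, 0], [0, 1, 0], [1, 1, 1], [0, 0, 0]]

ml_accuracy = subset_accuracy(truth, prediction)  # only the second row matches: 0.25
ml_f1 = multilabel_f1(truth, prediction)  # macro 0.6, micro 8/11, weighted 4.4/6
```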