Criterions

Criterions are specified in the dataset configuration using Shorthand Syntax and are used to calculate the loss of the model.

Tip

To create custom criterions, refer to the custom criterions tutorial.

Criterion Wrappers#

As the DataLoader may automatically cast data to a type that the loss function does not expect, some loss functions need to be wrapped to cast the model outputs and targets to the correct types. For more information, see this discussion.
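The following sketch, in plain PyTorch and independent of autrainer, illustrates the kind of mismatch meant here: 1D targets that hold class indices as floats, as an automatic cast might produce, are rejected by torch.nn.CrossEntropyLoss, while the same targets cast to long are accepted. The tensors are arbitrary example values.

```python
import torch

logits = torch.randn(4, 3)  # batch of 4 samples, 3 classes
float_targets = torch.tensor([0.0, 2.0, 1.0, 2.0])  # class indices stored as floats

criterion = torch.nn.CrossEntropyLoss()

try:
    criterion(logits, float_targets)
    failed = False
except RuntimeError:
    # 1D targets are interpreted as class indices, which must be of dtype long
    failed = True

# Casting the targets to long makes the same call succeed
loss = criterion(logits, float_targets.long())
```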

autrainer provides the following wrappers:

class autrainer.criterions.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

Wrapper for torch.nn.CrossEntropyLoss; accepts the same arguments.

forward(x, y)

Wrapper for torch.nn.CrossEntropyLoss.forward.

Converts the targets to long if they are a 1D tensor (class indices), as torch.nn.CrossEntropyLoss requires class-index targets of type long.

Parameters:
  • x (Tensor) – Batched model outputs.

  • y (Tensor) – Targets.

Return type:

Tensor

Returns:

Loss.
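A minimal sketch of the casting behavior described above, written as a subclass of torch.nn.CrossEntropyLoss. The class name CastingCrossEntropyLoss is hypothetical, and this is an illustration of the documented behavior, not the autrainer source.

```python
import torch


class CastingCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """Hypothetical re-implementation of the documented cast, not autrainer's code."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # 1D targets hold class indices, which torch requires as int64 (long)
        if y.dim() == 1:
            y = y.long()
        # Other targets (e.g. per-class probabilities) are passed through unchanged
        return super().forward(x, y)


criterion = CastingCrossEntropyLoss()
logits = torch.randn(4, 3)
targets = torch.tensor([0.0, 2.0, 1.0, 2.0])  # float class indices
loss = criterion(logits, targets)
```

Because the cast only applies to 1D targets, probability targets with the same shape as the outputs keep their floating-point dtype.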

class autrainer.criterions.BalancedCrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)

Cross-entropy loss with class weights derived from the target frequencies in the training set (see setup()); accepts the same arguments as torch.nn.CrossEntropyLoss.

setup(data)

Calculates balanced class weights based on the target frequencies in the training set.

Parameters:

data (AbstractDataset) – Instance of the dataset.

Return type:

None
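The exact weighting formula is not stated here. As an assumption, the sketch below uses a common inverse-frequency heuristic, n_samples / (n_classes * count_c) (the same rule as scikit-learn's "balanced" class weights), and passes the result to torch.nn.CrossEntropyLoss. The helper balanced_class_weights is hypothetical, and autrainer's normalization may differ.

```python
import torch


def balanced_class_weights(targets: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: inverse-frequency class weights.

    Uses the n_samples / (n_classes * count) heuristic; autrainer's exact
    formula may differ.
    """
    counts = torch.bincount(targets, minlength=num_classes).float()
    # Guard against division by zero for classes absent from the training set
    counts = counts.clamp(min=1.0)
    return targets.numel() / (num_classes * counts)


train_targets = torch.tensor([0, 0, 0, 1])  # class 0 is three times as frequent
weights = balanced_class_weights(train_targets, num_classes=2)
criterion = torch.nn.CrossEntropyLoss(weight=weights)
```

With these targets, the rare class 1 receives a weight of 2.0 and the frequent class 0 a weight of about 0.667, so errors on the minority class contribute more to the loss.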

class autrainer.criterions.MSELoss(size_average=None, reduce=None, reduction='mean')

Wrapper for torch.nn.MSELoss; accepts the same arguments.

forward(x, y)

Wrapper for torch.nn.MSELoss.forward.

Squeezes the model outputs along the last dimension to match the shape of the targets. Converts the targets to float.

Parameters:
  • x (Tensor) – Batched model outputs.

  • y (Tensor) – Targets.

Return type:

Tensor

Returns:

Loss.
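A minimal sketch of the documented squeeze-and-cast behavior as a subclass of torch.nn.MSELoss. The class name SqueezingMSELoss is hypothetical; this illustrates the description above and is not the autrainer source.

```python
import torch


class SqueezingMSELoss(torch.nn.MSELoss):
    """Hypothetical re-implementation of the documented behavior, not autrainer's code."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Model outputs of shape (batch, 1) become (batch,) to match the targets
        x = x.squeeze(-1)
        # Targets may arrive as integers; MSE needs floating-point values
        return super().forward(x, y.float())


criterion = SqueezingMSELoss()
outputs = torch.tensor([[1.0], [2.0], [4.0]])  # shape (3, 1)
targets = torch.tensor([1, 2, 3])  # integer targets, shape (3,)
loss = criterion(outputs, targets)
```

Without the squeeze, PyTorch would broadcast the (3, 1) outputs against the (3,) targets into a (3, 3) difference and compute a different loss than intended.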

Torch Criterions

Torch criterions, such as torch.nn.BCEWithLogitsLoss for multi-label classification tasks, can be specified using Shorthand Syntax in the dataset configuration, analogous to the criterion wrappers.

Note

It may be necessary to wrap the criterion in the same way as autrainer.criterions.CrossEntropyLoss or autrainer.criterions.MSELoss to cast the model outputs and targets to the correct types.
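As an illustration for the multi-label case, the sketch below wraps torch.nn.BCEWithLogitsLoss so that integer multi-hot targets are cast to the dtype of the logits before the loss is computed. The class name FloatTargetBCEWithLogitsLoss is hypothetical and not part of autrainer.

```python
import torch


class FloatTargetBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
    """Hypothetical wrapper; not part of autrainer."""

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Multi-label targets are 0/1 indicators per class; the loss needs
        # them as floating-point values of the same dtype as the logits
        return super().forward(x, y.to(x.dtype))


criterion = FloatTargetBCEWithLogitsLoss()
logits = torch.zeros(2, 3)  # batch of 2 samples, 3 labels
targets = torch.tensor([[1, 0, 1], [0, 1, 0]])  # integer multi-hot targets
loss = criterion(logits, targets)
```

With all logits at zero, every predicted probability is 0.5, so each element contributes log(2) and the mean loss is log(2).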

Criterion Setup

Criterions can optionally provide a setup() method, which is called after the criterion is initialized and receives the dataset instance as its argument. This can be used to set up additional parameters, such as class weights for imbalanced datasets.

example_loss.ExampleLoss
import torch

class ExampleLoss(torch.nn.modules.loss._Loss):
    def setup(self, data: "AbstractDataset") -> None: ...  # receives the dataset instance after initialization
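To show the lifecycle end to end, the following sketch defines a hypothetical criterion whose setup() derives the mean and standard deviation of the training targets and whose forward() computes the MSE against targets standardized with those statistics (so the model is assumed to predict standardized values). The attribute data.train_targets and the SimpleNamespace stand-in for the dataset are assumptions made only for this example; they are not part of the AbstractDataset API.

```python
from types import SimpleNamespace

import torch


class StandardizedMSELoss(torch.nn.modules.loss._Loss):
    """Hypothetical criterion: MSE on targets standardized with training statistics."""

    def __init__(self) -> None:
        super().__init__()
        # Placeholder statistics until setup() is called with the dataset
        self.mean = 0.0
        self.std = 1.0

    def setup(self, data) -> None:
        # ASSUMPTION: `data.train_targets` is a hypothetical attribute used only
        # for this sketch; it is not part of the AbstractDataset API
        targets = torch.as_tensor(data.train_targets, dtype=torch.float32)
        self.mean = targets.mean().item()
        self.std = targets.std().item()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Standardize the targets with the statistics collected in setup()
        y = (y.float() - self.mean) / self.std
        return torch.nn.functional.mse_loss(x.squeeze(-1), y, reduction=self.reduction)


# Stand-in for a dataset instance, for illustration only
data = SimpleNamespace(train_targets=[1.0, 2.0, 3.0])
criterion = StandardizedMSELoss()
criterion.setup(data)  # called once after initialization
loss = criterion(torch.zeros(3, 1), torch.tensor([1.0, 2.0, 3.0]))
```

Keeping dataset-dependent state out of __init__ lets the criterion be constructed from the configuration alone, with the dataset-specific parameters filled in once the dataset is available.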