Criterions#
Criterions are specified in the dataset configuration using Shorthand Syntax and compute the loss of the model during training.
Tip
To create custom criterions, refer to the custom criterions tutorial.
Criterion Wrappers#
As the DataLoader may automatically cast tensors to an unexpected type, some loss functions need to be wrapped so that the model outputs and targets are cast to the correct types before the loss is computed. For more information, see this discussion.
autrainer provides the following wrappers:
- class autrainer.criterions.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)[source]#
Wrapper around torch.nn.CrossEntropyLoss that casts the model outputs and targets to the correct types before computing the loss.
- class autrainer.criterions.BalancedCrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)[source]#
Wrapper around torch.nn.CrossEntropyLoss that additionally balances class weights based on the target frequency in the training set.
- setup(data)[source]#
Calculate balanced weights for the dataset based on the target frequency in the training set.
- Parameters:
data (AbstractDataset) – Instance of the dataset.
- Return type:
None
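To illustrate what such a wrapper does, the following is a minimal sketch of a cross-entropy wrapper that casts outputs to float and targets to long before delegating to torch.nn.CrossEntropyLoss. The class name and implementation are illustrative assumptions based on the description above, not the actual autrainer code.

```python
import torch


class CastingCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """Illustrative sketch (not the actual autrainer implementation):
    cast model outputs and targets to the dtypes expected by
    torch.nn.CrossEntropyLoss before computing the loss."""

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # CrossEntropyLoss expects float logits and long class indices;
        # the DataLoader may deliver targets as another integer dtype.
        return super().forward(outputs.float(), targets.long())
```

With this wrapper, integer targets of any dtype are accepted, whereas the unwrapped loss would reject them.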
Torch Criterions#
Torch criterions, such as torch.nn.BCEWithLogitsLoss for multi-label classification tasks, can be specified using Shorthand Syntax in the dataset configuration, analogous to the criterion wrappers.
Note
It may be necessary to wrap the criterion similar to autrainer.criterions.CrossEntropyLoss or autrainer.criterions.MSELoss to cast the model outputs and targets to the correct types.
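As a concrete case of the note above, torch.nn.BCEWithLogitsLoss expects floating-point targets, but multi-label targets often arrive from the DataLoader as integers. A minimal hypothetical wrapper (not part of autrainer) could cast them:

```python
import torch


class CastingBCEWithLogitsLoss(torch.nn.BCEWithLogitsLoss):
    """Hypothetical wrapper: BCEWithLogitsLoss requires float targets,
    so integer multi-label targets are cast before the loss is computed."""

    def forward(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        return super().forward(outputs.float(), targets.float())
```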
Criterion Setup#
Criterions can optionally provide a setup() method, which is called after the criterion is initialized and takes the dataset instance as an argument. This can be used to set up additional parameters, such as class weights for imbalanced datasets.
class ExampleLoss(torch.nn.modules.loss._Loss):
    def setup(self, data: "AbstractDataset") -> None: ...
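A fuller sketch of such a setup() hook is shown below, deriving inverse-frequency class weights from the training targets. The dataset attributes (`df_train`, `target_column`) and the class name are assumptions for illustration; they are not the actual AbstractDataset API or the BalancedCrossEntropyLoss implementation.

```python
from collections import Counter

import torch


class BalancedExampleLoss(torch.nn.CrossEntropyLoss):
    """Illustrative sketch of a setup() hook that balances class weights
    based on target frequencies in the training set."""

    def setup(self, data) -> None:
        # Assumed attributes: a training table and the name of its
        # target column. The real AbstractDataset API may differ.
        targets = data.df_train[data.target_column]
        counts = Counter(targets)
        total = sum(counts.values())
        # Inverse-frequency weighting: rarer classes get larger weights.
        weights = [total / counts[label] for label in sorted(counts)]
        self.weight = torch.tensor(weights, dtype=torch.float32)
```

With three samples of class 0 and one of class 1, this assigns weights 4/3 and 4, so errors on the rare class contribute more to the loss.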