Criterions#
Criterions are specified in the Dataset configuration with Shorthand Syntax and are used to calculate the loss of the model.
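For illustration, a criterion might appear in a dataset configuration as follows (a hypothetical fragment; the surrounding keys and exact structure of the dataset configuration are assumptions, so consult your own config):

```yaml
# hypothetical fragment of a dataset configuration
# (the criterion key is illustrative; arguments can likewise be passed
# via the same shorthand syntax)
criterion: autrainer.criterions.BalancedCrossEntropyLoss
```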
Tip
To create custom criterions, refer to the custom criterions tutorial.
Note
The reduction attribute of each criterion is automatically set to "none" during training, validation, and testing. This allows the per-example loss to be reported directly, without recalculating it for logging purposes. This is handled automatically during the instantiation of the criterion.
Criterion Wrappers#
As the DataLoader may automatically cast model outputs and targets to the wrong type, some loss functions need to be wrapped to cast them back to the correct types. For more information, see this discussion.
autrainer provides the following wrappers:
- class autrainer.criterions.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)[source]#
Wrapper for torch.nn.CrossEntropyLoss that casts the model outputs and targets to the correct types.
- class autrainer.criterions.BCEWithLogitsLoss(weight=None, size_average=None, reduce=None, reduction='mean', pos_weight=None)[source]#
Wrapper for torch.nn.BCEWithLogitsLoss that casts the model outputs and targets to the correct types.
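The casting such a wrapper performs can be sketched as a thin subclass (a minimal sketch, not autrainer's actual implementation; the class name is hypothetical):

```python
import torch


class CastingCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    """Hypothetical sketch: cast targets before computing the loss."""

    def forward(
        self, outputs: torch.Tensor, targets: torch.Tensor
    ) -> torch.Tensor:
        # DataLoader collation may yield float targets, while
        # torch.nn.CrossEntropyLoss expects long class indices of shape (N,)
        return super().forward(outputs, targets.long())
```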
Balanced and Weighted Criterions#
For imbalanced datasets, it is often beneficial to use a balanced or weighted loss function. Balanced loss functions automatically adjust the loss for each class or target based on their frequency using a setup function. Weighted loss functions allow for the manual specification of weights for each class or target.
- class autrainer.criterions.BalancedCrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean', label_smoothing=0.0)[source]#
Balanced version of torch.nn.CrossEntropyLoss with class weights calculated based on the target frequency in the training set.
- setup(data)[source]#
Calculate balanced weights for the dataset based on the target frequency in the training set.
- Parameters:
data (AbstractDataset) – Instance of the dataset.
- Return type:
None
- class autrainer.criterions.WeightedCrossEntropyLoss(class_weights, **kwargs)[source]#
Wrapper for torch.nn.CrossEntropyLoss with manual class weights.
The class weights are automatically normalized to sum up to the number of classes.
- Parameters:
class_weights (Dict[str, float]) – Dictionary mapping the target labels to their respective weights.
**kwargs – Additional keyword arguments passed to torch.nn.CrossEntropyLoss.
- setup(data)[source]#
Calculate the class weights based on the provided dictionary.
- Parameters:
data (AbstractDataset) – Instance of the dataset.
- Return type:
None
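The normalization described above (weights rescaled to sum to the number of classes) can be sketched as follows (`normalize_class_weights` is a hypothetical helper for illustration, not part of autrainer's API):

```python
from typing import Dict


def normalize_class_weights(class_weights: Dict[str, float]) -> Dict[str, float]:
    # hypothetical helper: rescale weights so they sum to the number of
    # classes, mirroring the normalization described for
    # WeightedCrossEntropyLoss
    total = sum(class_weights.values())
    n_classes = len(class_weights)
    return {label: w * n_classes / total for label, w in class_weights.items()}


# relative proportions are preserved; only the overall scale changes
normalize_class_weights({"cat": 1.0, "dog": 3.0})
# → {"cat": 0.5, "dog": 1.5}, which sums to 2, the number of classes
```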
- class autrainer.criterions.BalancedBCEWithLogitsLoss(weight=None, reduction='mean')[source]#
Balanced version of torch.nn.BCEWithLogitsLoss.
pos_weight is not supported, as the weights are calculated based on the target frequency in the training set.
- Parameters:
weight (Optional[Tensor]) – A manual rescaling weight given to the positive class. Defaults to None.
reduction (str) – Specifies the reduction to apply to the output. Defaults to 'mean'.
- setup(data)[source]#
Calculate balanced weights for the dataset based on the target frequency in the training set.
- Parameters:
data (AbstractDataset) – Instance of the dataset.
- Return type:
None
- class autrainer.criterions.WeightedBCEWithLogitsLoss(class_weights, **kwargs)[source]#
Wrapper for torch.nn.BCEWithLogitsLoss with manual class weights.
The class weights are automatically normalized to sum up to the number of classes.
- Parameters:
class_weights (Dict[str, float]) – Dictionary mapping the target labels to their respective weights.
**kwargs – Additional keyword arguments passed to torch.nn.BCEWithLogitsLoss.
- setup(data)[source]#
Calculate the class weights based on the provided dictionary.
- Parameters:
data (AbstractDataset) – Instance of the dataset.
- Return type:
None
- class autrainer.criterions.WeightedMSELoss(target_weights, **kwargs)[source]#
Wrapper for torch.nn.MSELoss with manual target weights intended for multi-target regression tasks.
The target weights are automatically normalized to sum up to the number of targets.
- Parameters:
target_weights (Dict[str, float]) – Dictionary mapping the target labels to their respective weights.
**kwargs – Additional keyword arguments passed to torch.nn.MSELoss.
- setup(data)[source]#
Calculate the target weights based on the provided dictionary.
- Parameters:
data (AbstractDataset) – Instance of the dataset.
- Return type:
None
Torch Criterions#
Other torch criterions, such as torch.nn.HuberLoss or torch.nn.SmoothL1Loss, can be specified using Shorthand Syntax in the dataset configuration, analogous to the criterion wrappers.
Note
It may be necessary to wrap criterions, similar to autrainer.criterions.CrossEntropyLoss or autrainer.criterions.MSELoss, to cast the model outputs and targets to the correct types.
Criterion Setup#
Criterions can optionally provide a setup() method, which is called after the criterion is initialized and takes the dataset instance as an argument.
This can be used to set up additional parameters, such as class weights for imbalanced datasets.
class ExampleLoss(torch.nn.modules.loss._Loss):
    def setup(self, data: "AbstractDataset") -> None: ...
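As a concrete sketch of what a setup() hook might compute, the following derives inverse-frequency class weights from a list of training targets. The class name, the inverse-frequency formula, and the list-based input are assumptions for illustration; autrainer passes the dataset instance itself, and its balanced criterions may use a different weighting scheme:

```python
from collections import Counter
from typing import Dict, List


class ExampleBalancedLoss:
    """Hypothetical criterion with a setup() hook (sketch only)."""

    weights: Dict[str, float]

    def setup(self, train_targets: List[str]) -> None:
        # inverse-frequency weighting: rarer classes receive larger weights
        counts = Counter(train_targets)
        total = sum(counts.values())
        self.weights = {
            label: total / (len(counts) * count)
            for label, count in counts.items()
        }


loss = ExampleBalancedLoss()
loss.setup(["a", "a", "a", "b"])
# the rare class "b" receives a larger weight than the frequent class "a"
```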