Optimizers#

Any torch optimizer or custom optimizer can be used for training.

Tip

To create custom optimizers, refer to the custom optimizers tutorial.

Torch Optimizers#

Torch optimizers (torch.optim) can be specified as relative Python import paths using the _target_ attribute in the configuration file. Any additional attributes (except id) are passed as keyword arguments to the optimizer constructor.

For example, torch.optim.Adam can be used as follows:

conf/optimizer/Adam.yaml#
id: Adam
_target_: torch.optim.Adam
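
Additional keyword arguments can be added as attributes alongside _target_. The configuration below is only a sketch: the Adam-Custom id and file name are hypothetical and the values are illustrative, but weight_decay and amsgrad are standard torch.optim.Adam arguments:

conf/optimizer/Adam-Custom.yaml#
id: Adam-Custom
_target_: torch.optim.Adam
weight_decay: 0.0001 # illustrative value, passed to torch.optim.Adam
amsgrad: true # passed to torch.optim.Adam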

autrainer provides a number of default configurations for torch optimizers:

Default Configurations
conf/optimizer/SGD.yaml#
id: SGD
_target_: torch.optim.SGD
conf/optimizer/Adam.yaml#
id: Adam
_target_: torch.optim.Adam
conf/optimizer/AdamW.yaml#
id: AdamW
_target_: torch.optim.AdamW
conf/optimizer/Adadelta.yaml#
id: Adadelta
_target_: torch.optim.Adadelta
conf/optimizer/Adagrad.yaml#
id: Adagrad
_target_: torch.optim.Adagrad
conf/optimizer/Adamax.yaml#
id: Adamax
_target_: torch.optim.Adamax
conf/optimizer/NAdam.yaml#
id: NAdam
_target_: torch.optim.NAdam
conf/optimizer/RAdam.yaml#
id: RAdam
_target_: torch.optim.RAdam
conf/optimizer/SparseAdam.yaml#
id: SparseAdam
_target_: torch.optim.SparseAdam
conf/optimizer/RMSprop.yaml#
id: RMSprop
_target_: torch.optim.RMSprop
conf/optimizer/Rprop.yaml#
id: Rprop
_target_: torch.optim.Rprop
conf/optimizer/ASGD.yaml#
id: ASGD
_target_: torch.optim.ASGD
conf/optimizer/LBFGS.yaml#
id: LBFGS
_target_: torch.optim.LBFGS

Custom Optimizers#

class autrainer.optimizers.SAM(params, base_optimizer, rho=0.05, adaptive=False, **kwargs)#

Sharpness Aware Minimization (SAM) optimizer.

This implementation is adapted from the following repository: davda54/sam

For more information, see: https://arxiv.org/abs/2010.01412

Parameters:
  • params (Module) – Model parameters.

  • base_optimizer (str) – Underlying optimizer performing the sharpness-aware update, specified as a relative import path.

  • rho (float) – Size of the neighborhood for computing the max loss. Defaults to 0.05.

  • adaptive (bool) – Whether to use an experimental implementation of element-wise Adaptive SAM. Defaults to False.

  • **kwargs – Additional arguments passed to the underlying optimizer.

Default Configurations
conf/optimizer/SAM-SGD.yaml#
id: SAM-SGD
_target_: autrainer.optimizers.SAM
base_optimizer: torch.optim.SGD
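
Any additional attributes that SAM does not consume itself are forwarded to the underlying optimizer. The configuration below is only a sketch: the SAM-SGD-Momentum id and file name are hypothetical and the values are illustrative, while rho and adaptive are SAM parameters and momentum is a standard torch.optim.SGD argument:

conf/optimizer/SAM-SGD-Momentum.yaml#
id: SAM-SGD-Momentum
_target_: autrainer.optimizers.SAM
base_optimizer: torch.optim.SGD
rho: 0.1 # SAM neighborhood size (illustrative value)
adaptive: true # element-wise Adaptive SAM
momentum: 0.9 # forwarded to torch.optim.SGD (illustrative value)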

Custom Step Function#

Custom optimizers can optionally provide a custom_step() function, which is called in place of the standard training step. It should be defined as follows:

custom_step function of an optimizer#
from typing import Callable, Tuple

import torch

# Assumed import locations for AbstractModel and DataBatch; adjust them if
# they differ in your version of autrainer.
from autrainer.datasets.utils import DataBatch
from autrainer.models import AbstractModel


class SomeOptimizer(torch.optim.Optimizer):
    def custom_step(
        self,
        model: AbstractModel,
        data: DataBatch,
        criterion: torch.nn.Module,
        probabilities_fn: Callable,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Custom step function for the optimizer.

        Args:
            model: The model to train.
            data: The data batch containing features, target, and potentially
                additional fields. The data batch is on the same device as the
                model. Additional fields are passed to the model as keyword
                arguments if they are present in the model's forward method.
            criterion: Loss function.
            probabilities_fn: Function to convert model outputs to
                probabilities.

        Returns:
            Tuple containing the non-reduced loss and model outputs.
        """

Note

The custom_step() function should perform both the forward and backward pass and update the model parameters accordingly.
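
As a rough sketch of what such a step might look like (and not the actual SAM implementation), the example below performs a single forward pass, a backward pass on the reduced loss, and a parameter update. It assumes that the criterion operates on the raw model outputs and returns a non-reduced loss, and that only the features field of the data batch is needed; depending on the training setup, probabilities_fn may need to be applied to the outputs before the criterion. The import paths follow the assumptions of the signature example above.

Sketch of a minimal custom_step implementation#
from typing import Callable, Tuple

import torch

# Assumed import locations, as in the signature example above.
from autrainer.datasets.utils import DataBatch
from autrainer.models import AbstractModel


class SomeOptimizer(torch.optim.SGD):
    def custom_step(
        self,
        model: AbstractModel,
        data: DataBatch,
        criterion: torch.nn.Module,
        probabilities_fn: Callable,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self.zero_grad()
        outputs = model(data.features)  # forward pass
        # non-reduced (per-sample) loss; apply probabilities_fn to the outputs
        # first if the criterion expects probabilities rather than raw outputs
        loss = criterion(outputs, data.target)
        loss.mean().backward()  # backward pass on the reduced loss
        self.step()  # update the model parameters
        return loss.detach(), outputs.detach()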