Optimizers

Any torch optimizer or custom optimizer can be used for training.

Tip

To create custom optimizers, refer to the custom optimizers tutorial.

Torch Optimizers

Torch optimizers (torch.optim) can be specified as relative Python import paths in the _target_ attribute of the configuration file. Any additional attributes (except id) are passed as keyword arguments to the optimizer constructor.

For example, torch.optim.Adam can be used as follows:

conf/optimizer/Adam.yaml
id: Adam
_target_: torch.optim.Adam
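The following sketch shows roughly what happens to such a configuration. The _target_ path is imported, id is dropped, and every remaining attribute is forwarded as a keyword argument. The helper instantiate_optimizer is hypothetical and only illustrates the behavior described above. It is not autrainer's own instantiation code, and the weight_decay attribute is an arbitrary example value.

```python
import importlib
from typing import Any, Dict, Iterable

import torch


def instantiate_optimizer(config: Dict[str, Any], params: Iterable) -> torch.optim.Optimizer:
    """Hypothetical helper: build an optimizer from a configuration mapping."""
    kwargs = dict(config)              # copy so the caller's config is left untouched
    kwargs.pop("id", None)             # `id` only names the configuration
    target = kwargs.pop("_target_")    # e.g. "torch.optim.Adam"
    module_path, _, class_name = target.rpartition(".")
    optimizer_cls = getattr(importlib.import_module(module_path), class_name)
    # All remaining attributes become constructor keyword arguments.
    return optimizer_cls(params, **kwargs)


# Mirrors conf/optimizer/Adam.yaml plus one illustrative extra attribute.
config = {"id": "Adam", "_target_": "torch.optim.Adam", "weight_decay": 0.01}
model = torch.nn.Linear(4, 2)
optimizer = instantiate_optimizer(config, model.parameters())
```

Copying the mapping before popping keys keeps the original configuration reusable, for example when the same optimizer configuration is instantiated for several runs.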

autrainer provides a number of default configurations for torch optimizers:

Default Configurations

Adam

conf/optimizer/Adam.yaml
id: Adam
_target_: torch.optim.Adam
conf/optimizer/Adamax.yaml
id: Adamax
_target_: torch.optim.Adamax
conf/optimizer/AdamW.yaml
id: AdamW
_target_: torch.optim.AdamW

SGD

conf/optimizer/SGD.yaml
id: SGD
_target_: torch.optim.SGD

Custom Optimizers

class autrainer.optimizers.SAM(params, base_optimizer, rho=0.05, adaptive=False, **kwargs)

Sharpness Aware Minimization (SAM) optimizer.

This implementation is adapted from the following repository: davda54/sam

For more information, see: https://arxiv.org/abs/2010.01412

Parameters:
  • params (Module) – Model parameters.

  • base_optimizer (str) – Underlying optimizer performing the sharpness-aware update, specified as a relative import path.

  • rho (float) – Size of the neighborhood for computing the max loss. Defaults to 0.05.

  • adaptive (bool) – Whether to use an experimental implementation of element-wise Adaptive SAM. Defaults to False.

  • **kwargs – Additional arguments passed to the underlying optimizer.
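The roles of rho and adaptive can be summarized by the perturbation SAM applies before the actual update. This follows the SAM paper and, for the element-wise adaptive variant, the davda54/sam reference code. In practice a small constant is added to the norm to avoid division by zero. The symbols g, T_w and η are notation introduced here, not names used by autrainer.

```latex
% g = \nabla_w L(w): gradient of the training loss at the current weights w
\hat{\epsilon}(w) = \rho \, \frac{T_w^{2} \, g}{\lVert T_w \, g \rVert_2},
\qquad
T_w =
\begin{cases}
  \operatorname{diag}(\lvert w \rvert) & \text{if adaptive} \\
  I & \text{otherwise}
\end{cases}

% The base optimizer then updates the original w using the gradient
% evaluated at w + \hat{\epsilon}(w); for plain SGD with learning rate \eta:
w \leftarrow w - \eta \, \nabla_w L(w) \big|_{w + \hat{\epsilon}(w)}
```

With adaptive=False, this reduces to a step of length rho along the normalized gradient.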

Default Configurations
conf/optimizer/SAM-SGD.yaml
id: SAM-SGD
_target_: autrainer.optimizers.SAM
base_optimizer: torch.optim.SGD
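To illustrate the two forward/backward passes SAM needs per update, here is a standalone sketch in plain PyTorch. It follows the approach of the davda54/sam reference code but is not autrainer's SAM class. The function sam_step and its zero-argument loss_fn argument are hypothetical. The sketch also omits refinements such as keeping batch-normalization statistics from being updated on both forward passes.

```python
import torch


def sam_step(model, loss_fn, base_optimizer, rho=0.05, adaptive=False):
    """Hypothetical single SAM update; loss_fn() runs a forward pass and returns the loss."""
    params = [p for p in model.parameters() if p.requires_grad]

    # First pass: gradient g at the current weights w.
    base_optimizer.zero_grad()
    loss = loss_fn()
    loss.backward()

    with torch.no_grad():
        # ||T_w g||_2, where T_w = |w| element-wise for adaptive SAM, else identity.
        grad_norm = torch.norm(
            torch.stack([
                ((p.abs() if adaptive else 1.0) * p.grad).norm(p=2)
                for p in params if p.grad is not None
            ]),
            p=2,
        )
        scale = rho / (grad_norm + 1e-12)
        perturbations = []
        for p in params:
            if p.grad is None:
                perturbations.append(None)
                continue
            e_w = (p.pow(2) if adaptive else 1.0) * p.grad * scale
            p.add_(e_w)  # move to the perturbed point w + e(w)
            perturbations.append(e_w)

    # Second pass: gradient at the perturbed weights.
    base_optimizer.zero_grad()
    loss_fn().backward()

    with torch.no_grad():
        for p, e_w in zip(params, perturbations):
            if e_w is not None:
                p.sub_(e_w)  # restore the original weights w

    # The base optimizer applies the sharpness-aware gradient at w.
    base_optimizer.step()
    base_optimizer.zero_grad()
    return loss.item()
```

With rho=0 the perturbation vanishes and the update reduces to a plain step of the base optimizer, which is a convenient sanity check.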

Custom Step Function

Custom optimizers can optionally provide a custom_step() method. If present, it is called instead of the standard training step and should be defined as follows:

custom_step function of an optimizer
from typing import Callable, Tuple

import torch


class SomeOptimizer(torch.optim.Optimizer):
    def custom_step(
        self,
        model: torch.nn.Module,
        data: torch.Tensor,
        target: torch.Tensor,
        criterion: torch.nn.Module,
        probabilities_fn: Callable,
    ) -> Tuple[float, torch.Tensor]:
        """Custom step function for the optimizer.

        Args:
            model: Model to be optimized.
            data: Batched input data.
            target: Batched target data.
            criterion: Loss function.
            probabilities_fn: Function to get probabilities from model outputs.

        Returns:
            Reduced loss over the batch and detached model outputs.
        """

Note

The custom_step() function must perform the forward pass, the backward pass, and the update of the model parameters itself, since it replaces the standard training step.
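As a minimal sketch of this contract, the hypothetical optimizer below subclasses torch.optim.SGD and implements custom_step() as an ordinary training step. The class name PlainStepSGD is illustrative. Applying criterion to probabilities_fn(output) rather than to the raw output is an assumption of this sketch and should be checked against the standard training step.

```python
from typing import Callable, Tuple

import torch


class PlainStepSGD(torch.optim.SGD):
    """Hypothetical optimizer whose custom_step() mimics an ordinary step."""

    def custom_step(
        self,
        model: torch.nn.Module,
        data: torch.Tensor,
        target: torch.Tensor,
        criterion: torch.nn.Module,
        probabilities_fn: Callable,
    ) -> Tuple[float, torch.Tensor]:
        self.zero_grad()
        output = model(data)  # forward pass
        # Assumption: the loss is computed on probabilities_fn(output).
        loss = criterion(probabilities_fn(output), target)
        loss.backward()  # backward pass
        self.step()  # parameter update
        # Reduced loss as a Python float, outputs detached from the graph.
        return loss.item(), output.detach()
```

An optimizer such as SAM would instead run two forward/backward passes inside custom_step() before returning the same pair of values.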