Optimizers#
Any torch optimizer or custom optimizer can be used for training.
Tip
To create custom optimizers, refer to the custom optimizers tutorial.
Torch Optimizers#
Torch optimizers (torch.optim) can be specified as relative Python import paths for the _target_ attribute in the configuration file.
Any additional attributes (except id) are passed as keyword arguments to the optimizer constructor.
For example, torch.optim.Adam can be used as follows:
id: Adam
_target_: torch.optim.Adam
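Since additional attributes are forwarded to the constructor, hyperparameters such as lr or weight_decay can be set directly in the configuration. A sketch (the id and the hyperparameter values here are illustrative, not defaults shipped with autrainer):

```yaml
id: Adam-Low-LR            # hypothetical configuration name
_target_: torch.optim.Adam
lr: 0.0001                 # passed to torch.optim.Adam as a keyword argument
weight_decay: 0.01         # likewise forwarded to the constructor
```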
autrainer provides a number of default configurations for torch optimizers:
Default Configurations

Adam

id: Adam
_target_: torch.optim.Adam

Adamax

id: Adamax
_target_: torch.optim.Adamax

AdamW

id: AdamW
_target_: torch.optim.AdamW

SGD

id: SGD
_target_: torch.optim.SGD
Custom Optimizers#
- class autrainer.optimizers.SAM(params, base_optimizer, rho=0.05, adaptive=False, **kwargs)[source]#
Sharpness Aware Minimization (SAM) optimizer.
This implementation is adapted from the following repository: davda54/sam
For more information, see: https://arxiv.org/abs/2010.01412
- Parameters:
  - params (Module) – Model parameters.
  - base_optimizer (str) – Underlying optimizer performing the sharpness-aware update, specified as a relative import path.
  - rho (float) – Size of the neighborhood for computing the max loss. Defaults to 0.05.
  - adaptive (bool) – Whether to use an experimental implementation of element-wise Adaptive SAM. Defaults to False.
  - **kwargs – Additional arguments passed to the underlying optimizer.
Default Configurations

SAM-SGD

id: SAM-SGD
_target_: autrainer.optimizers.SAM
base_optimizer: torch.optim.SGD
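Conceptually, SAM performs a two-step update: it first perturbs the weights toward the worst-case point in a rho-sized neighborhood, then computes gradients there and lets the base optimizer apply the step at the original weights. The following is a simplified sketch of that idea following the paper, not autrainer's actual implementation; the function name sam_update and the closure pattern are illustrative:

```python
import torch


def sam_update(params, closure, base_optimizer, rho=0.05):
    # Simplified sketch of Sharpness-Aware Minimization (Foret et al., 2020).
    # Assumes gradients from a first forward/backward pass are already
    # stored in p.grad for every parameter.
    params = [p for p in params if p.grad is not None]
    grad_norm = torch.norm(torch.stack([p.grad.norm(p=2) for p in params]))
    scale = rho / (grad_norm + 1e-12)
    perturbations = []
    with torch.no_grad():
        # step 1: climb to the approximate worst-case point in the rho-ball
        for p in params:
            e = p.grad * scale
            p.add_(e)
            perturbations.append(e)
    # step 2: recompute gradients at the perturbed weights
    base_optimizer.zero_grad()
    loss = closure()
    loss.backward()
    with torch.no_grad():
        # undo the perturbation, then apply the base optimizer's update
        for p, e in zip(params, perturbations):
            p.sub_(e)
    base_optimizer.step()
    return loss


# usage: one SAM step for a tiny regression model
model = torch.nn.Linear(3, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
closure = lambda: torch.nn.functional.mse_loss(model(x), y)
closure().backward()  # first forward/backward pass
loss = sam_update(list(model.parameters()), closure, opt)
```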
Custom Step Function#
Custom optimizers can optionally provide a custom_step() function that is called instead of the standard training step.
It should be defined as follows:
from typing import Tuple

import torch


class SomeOptimizer(torch.optim.Optimizer):
    def custom_step(
        self,
        model: torch.nn.Module,  # model
        data: torch.Tensor,  # batched input data
        target: torch.Tensor,  # batched target data
        criterion: torch.nn.Module,  # loss function
    ) -> Tuple[float, torch.Tensor]:
        loss = ...  # reduced loss over the batch
        outputs = ...  # detached model outputs
        return loss, outputs
Note
The custom_step() function should perform both the forward and backward passes and update the model parameters accordingly.
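As an illustration (the class name VerboseSGD and the hyperparameters are hypothetical, not part of autrainer), a custom_step built on top of torch.optim.SGD that runs the full training step might look like:

```python
from typing import Tuple

import torch


class VerboseSGD(torch.optim.SGD):
    # Hypothetical optimizer: custom_step performs the forward pass,
    # backward pass, and parameter update in a single call.
    def custom_step(
        self,
        model: torch.nn.Module,
        data: torch.Tensor,
        target: torch.Tensor,
        criterion: torch.nn.Module,
    ) -> Tuple[float, torch.Tensor]:
        self.zero_grad()
        outputs = model(data)  # forward pass
        loss = criterion(outputs, target)  # reduced loss over the batch
        loss.backward()  # backward pass
        self.step()  # update the model parameters
        return loss.item(), outputs.detach()


# usage with a tiny classification model
model = torch.nn.Linear(4, 2)
optimizer = VerboseSGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
data, target = torch.randn(8, 4), torch.randint(0, 2, (8,))
loss, outputs = optimizer.custom_step(model, data, target, criterion)
```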