Optimizers
Any torch optimizer or custom optimizer can be used for training.
Tip
To create custom optimizers, refer to the custom optimizers tutorial.
Torch Optimizers
Torch optimizers (torch.optim) can be specified as relative Python import paths for the _target_ attribute in the configuration file.
Any additional attributes (except id) are passed as keyword arguments to the optimizer constructor.
For example, torch.optim.Adam can be used as follows:
id: Adam
_target_: torch.optim.Adam
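The mapping from a configuration entry to an optimizer instance can be sketched in plain Python. This is an illustration under stated assumptions, not autrainer's actual loading code: the helpers resolve_target and build_optimizer are hypothetical, the configuration is written as a dict instead of a YAML file, and weight_decay is a hypothetical extra attribute used only to show that remaining keys are forwarded as keyword arguments.

```python
import importlib
from typing import Any, Dict, Iterable

import torch


def resolve_target(path: str) -> Any:
    # Hypothetical helper: split "torch.optim.Adam" into the module
    # "torch.optim" and the attribute "Adam", then import and fetch it.
    module_name, _, attr_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr_name)


def build_optimizer(
    config: Dict[str, Any], params: Iterable[torch.nn.Parameter]
) -> torch.optim.Optimizer:
    # Hypothetical helper: drop "id", use "_target_" to pick the class,
    # and forward every remaining key to the constructor.
    kwargs = {k: v for k, v in config.items() if k not in ("id", "_target_")}
    optimizer_cls = resolve_target(config["_target_"])
    return optimizer_cls(params, **kwargs)


model = torch.nn.Linear(4, 2)
config = {"id": "Adam", "_target_": "torch.optim.Adam", "weight_decay": 0.01}
optimizer = build_optimizer(config, model.parameters())
print(type(optimizer).__name__, optimizer.defaults["weight_decay"])  # Adam 0.01
```

Because the import path is resolved at runtime, any class reachable by a dotted path (including a custom optimizer) can be selected without changing the loading code.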
autrainer provides a number of default configurations for torch optimizers:
Default Configurations
Adam
id: Adam
_target_: torch.optim.Adam

id: Adamax
_target_: torch.optim.Adamax

id: AdamW
_target_: torch.optim.AdamW
SGD
id: SGD
_target_: torch.optim.SGD
Custom Optimizers
class autrainer.optimizers.SAM(params, base_optimizer, rho=0.05, adaptive=False, **kwargs)
Sharpness-Aware Minimization (SAM) optimizer.
This implementation is adapted from the following repository: davda54/sam
For more information, see: https://arxiv.org/abs/2010.01412
Parameters:
- params (Module) – Model parameters.
- base_optimizer (str) – Underlying optimizer performing the sharpness-aware update, specified as a relative import path.
- rho (float) – Size of the neighborhood for computing the max loss. Defaults to 0.05.
- adaptive (bool) – Whether to use an experimental implementation of element-wise Adaptive SAM. Defaults to False.
- **kwargs – Additional arguments passed to the underlying optimizer.
Default Configurations
id: SAM-SGD
_target_: autrainer.optimizers.SAM
base_optimizer: torch.optim.SGD
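The parameters above can be made concrete with a minimal sketch of the two-step SAM update, loosely following the structure of the davda54/sam repository. It is not autrainer's implementation: the class name SAMSketch and the methods first_step() and second_step() are illustrative assumptions. As in the documented signature, base_optimizer is an import path string and extra keyword arguments (here lr) are forwarded to it. The first step moves the weights by rho times the normalized gradient (scaled element-wise by the weights when adaptive=True); the second step restores the original weights and lets the base optimizer apply the gradient computed at the perturbed point.

```python
import importlib
from typing import Iterable

import torch


class SAMSketch(torch.optim.Optimizer):
    """Minimal SAM sketch (illustrative, not autrainer's implementation)."""

    def __init__(
        self,
        params: Iterable,
        base_optimizer: str,
        rho: float = 0.05,
        adaptive: bool = False,
        **kwargs,
    ) -> None:
        if rho < 0.0:
            raise ValueError(f"rho must be non-negative, got {rho}")
        super().__init__(params, dict(rho=rho, adaptive=adaptive, **kwargs))
        # Resolve the base optimizer from its import path, e.g. "torch.optim.SGD".
        module_name, _, class_name = base_optimizer.rpartition(".")
        base_cls = getattr(importlib.import_module(module_name), class_name)
        # The base optimizer shares our parameter groups; extra kwargs go to it.
        self.base_optimizer = base_cls(self.param_groups, **kwargs)
        self.param_groups = self.base_optimizer.param_groups

    @torch.no_grad()
    def _grad_norm(self) -> torch.Tensor:
        # L2 norm over all gradients (weighted by |w| in the adaptive variant).
        norms = [
            ((p.abs() if group["adaptive"] else 1.0) * p.grad).norm(p=2)
            for group in self.param_groups
            for p in group["params"]
            if p.grad is not None
        ]
        return torch.stack(norms).norm(p=2)

    @torch.no_grad()
    def first_step(self) -> None:
        # Move each weight to w + e(w), the approximate worst case within rho.
        denom = self._grad_norm() + 1e-12
        for group in self.param_groups:
            scale = group["rho"] / denom
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["old_p"] = p.detach().clone()
                e_w = (p.pow(2) if group["adaptive"] else 1.0) * p.grad * scale
                p.add_(e_w)
        self.zero_grad()

    @torch.no_grad()
    def second_step(self) -> None:
        # Restore the original weights, then update them with the gradient
        # that was computed at the perturbed point.
        for group in self.param_groups:
            for p in group["params"]:
                if "old_p" in self.state[p]:
                    p.copy_(self.state[p].pop("old_p"))
        self.base_optimizer.step()
        self.zero_grad()


torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
optimizer = SAMSketch(model.parameters(), "torch.optim.SGD", rho=0.05, lr=0.1)
criterion = torch.nn.MSELoss()
x, y = torch.randn(16, 3), torch.randn(16, 1)

initial_loss = criterion(model(x), y).item()
for _ in range(50):
    criterion(model(x), y).backward()  # gradient at w
    optimizer.first_step()             # w -> w + e(w)
    criterion(model(x), y).backward()  # gradient at w + e(w)
    optimizer.second_step()            # restore w, then SGD update
final_loss = criterion(model(x), y).item()
print(f"loss: {initial_loss:.4f} -> {final_loss:.4f}")
```

Each iteration needs two forward and backward passes around the weight perturbation, which is why SAM does not fit a standard training step with a single backward pass.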
Custom Step Function
Custom optimizers can optionally provide a custom_step() function that is called instead of the standard training step. It should be defined as follows:
from typing import Callable, Tuple

import torch


class SomeOptimizer(torch.optim.Optimizer):
    def custom_step(
        self,
        model: torch.nn.Module,
        data: torch.Tensor,
        target: torch.Tensor,
        criterion: torch.nn.Module,
        probabilities_fn: Callable,
    ) -> Tuple[float, torch.Tensor]:
        """Custom step function for the optimizer.

        Args:
            model: Model to be optimized.
            data: Batched input data.
            target: Batched target data.
            criterion: Loss function.
            probabilities_fn: Function to get probabilities from model outputs.

        Returns:
            Reduced loss over the batch and detached model outputs.
        """
        # Perform the forward pass, backward pass, and parameter update here.
Note
The custom_step() function should perform both the forward and backward passes and update the model parameters accordingly.
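A minimal sketch of a custom_step() that reproduces an ordinary training step may help as a starting point. The class PlainStepSGD is hypothetical, and two assumptions are made that this section does not confirm: the criterion is applied directly to the raw model outputs, and probabilities_fn is accepted but left unused, since how the training loop expects it to be used is not specified here. Calling zero_grad() at the start is a defensive choice so that stale gradients never leak into the update.

```python
from typing import Callable, Tuple

import torch


class PlainStepSGD(torch.optim.SGD):
    """Hypothetical optimizer whose custom_step() mirrors a standard step."""

    def custom_step(
        self,
        model: torch.nn.Module,
        data: torch.Tensor,
        target: torch.Tensor,
        criterion: torch.nn.Module,
        probabilities_fn: Callable,
    ) -> Tuple[float, torch.Tensor]:
        self.zero_grad()                  # clear stale gradients
        output = model(data)              # forward pass
        loss = criterion(output, target)  # assumption: criterion takes raw outputs
        loss.backward()                   # backward pass
        self.step()                       # update the model parameters
        # probabilities_fn is intentionally unused in this sketch.
        return loss.item(), output.detach()


torch.manual_seed(0)
model = torch.nn.Linear(4, 3)
optimizer = PlainStepSGD(model.parameters(), lr=0.1)
data, target = torch.randn(8, 4), torch.randint(0, 3, (8,))
weights_before = model.weight.detach().clone()

loss, output = optimizer.custom_step(
    model, data, target, torch.nn.CrossEntropyLoss(), lambda x: torch.softmax(x, dim=1)
)
print(loss, tuple(output.shape))  # output shape: (8, 3)
```

For an optimizer like SAM, the body would instead run two forward and backward passes around the weight perturbation before returning the loss and detached outputs.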