Optimizers#
Any torch optimizer or custom optimizer can be used for training.
Tip
To create custom optimizers, refer to the custom optimizers tutorial.
Torch Optimizers#
Torch optimizers (torch.optim) can be specified as relative Python import paths for the _target_ attribute in the configuration file.
Any additional attributes (except id) are passed as keyword arguments to the optimizer constructor.
For example, torch.optim.Adam can be used as follows:
id: Adam
_target_: torch.optim.Adam
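Since any additional attributes are forwarded to the constructor, optimizer arguments can be set directly in the configuration. A minimal sketch (the weight_decay and amsgrad values below are illustrative choices, not autrainer defaults):

id: Adam
_target_: torch.optim.Adam
weight_decay: 0.0001
amsgrad: true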
autrainer provides a number of default configurations for torch optimizers:
Default Configurations
id: SGD
_target_: torch.optim.SGD

id: Adam
_target_: torch.optim.Adam

id: AdamW
_target_: torch.optim.AdamW

id: Adadelta
_target_: torch.optim.Adadelta

id: Adagrad
_target_: torch.optim.Adagrad

id: Adamax
_target_: torch.optim.Adamax

id: NAdam
_target_: torch.optim.NAdam

id: RAdam
_target_: torch.optim.RAdam

id: SparseAdam
_target_: torch.optim.SparseAdam

id: RMSprop
_target_: torch.optim.RMSprop

id: Rprop
_target_: torch.optim.Rprop

id: ASGD
_target_: torch.optim.ASGD

id: LBFGS
_target_: torch.optim.LBFGS
Custom Optimizers#
class autrainer.optimizers.SAM(params, base_optimizer, rho=0.05, adaptive=False, **kwargs)
Sharpness Aware Minimization (SAM) optimizer.
This implementation is adapted from the following repository: davda54/sam
For more information, see: https://arxiv.org/abs/2010.01412
- Parameters:
  - params (Module) – Model parameters.
  - base_optimizer (str) – Underlying optimizer performing the sharpness-aware update, specified as a relative import path.
  - rho (float) – Size of the neighborhood for computing the max loss. Defaults to 0.05.
  - adaptive (bool) – Whether to use an experimental implementation of element-wise Adaptive SAM. Defaults to False.
  - **kwargs – Additional arguments passed to the underlying optimizer.
Default Configurations
conf/optimizer/SAM-SGD.yaml

id: SAM-SGD
_target_: autrainer.optimizers.SAM
base_optimizer: torch.optim.SGD
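Additional configuration attributes are forwarded as well; in the following sketch (the rho and momentum values are illustrative, not autrainer defaults), momentum is passed through **kwargs to the underlying torch.optim.SGD:

id: SAM-SGD
_target_: autrainer.optimizers.SAM
base_optimizer: torch.optim.SGD
rho: 0.05
momentum: 0.9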
Custom Step Function#
Custom optimizers can optionally provide a custom_step() function that is called instead of the standard training step and should be defined as follows:
class SomeOptimizer(torch.optim.Optimizer):
    def custom_step(
        self,
        model: AbstractModel,
        data: DataBatch,
        criterion: torch.nn.Module,
        probabilities_fn: Callable,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Custom step function for the optimizer.

        Args:
            model: The model to train.
            data: The data batch containing features, target, and potentially
                additional fields. The data batch is on the same device as the
                model. Additional fields are passed to the model as keyword
                arguments if they are present in the model's forward method.
            criterion: Loss function.
            probabilities_fn: Function to convert model outputs to
                probabilities.

        Returns:
            Tuple containing the non-reduced loss and model outputs.
        """
Note
The custom_step() function should perform both the forward and backward pass, as well as update the model parameters accordingly.
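For illustration, a minimal custom_step() could look as follows. This is a sketch only: the data.features and data.target attribute names follow the DataBatch description above, the optimizer reuses torch.optim.SGD for the parameter update, and the criterion is assumed to return a non-reduced (per-sample) loss.

from typing import Callable, Tuple

import torch


class SomeOptimizer(torch.optim.SGD):
    def custom_step(
        self,
        model: torch.nn.Module,
        data,  # DataBatch; assumed to expose .features and .target
        criterion: torch.nn.Module,
        probabilities_fn: Callable,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        self.zero_grad()
        # Forward pass; additional DataBatch fields could be passed as
        # keyword arguments if the model's forward method accepts them.
        outputs = model(data.features)
        # Non-reduced loss; assumes the criterion is configured with
        # reduction="none". probabilities_fn could be applied to the outputs
        # first if the criterion expects probabilities.
        loss = criterion(outputs, data.target)
        # Backward pass on the reduced loss and parameter update.
        loss.mean().backward()
        self.step()
        return loss.detach(), outputs.detach()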