Models#
autrainer provides a number of different audio-specific models as well as wrappers for torchvision and timm models.
Tip
To create custom models, refer to the custom models tutorial.
Default configurations that end in -T indicate that the model uses transfer learning with pretrained weights. To avoid race conditions when using Launcher Plugins that may run multiple training jobs in parallel, autrainer fetch or fetch() are used to download the pretrained weights before training.
Note
Most models are pretrained on the ImageNet or AudioSet datasets. To ensure compatibility with any number of output dimensions, the last linear layer of the model is replaced with a new linear layer with the correct number of output dimensions; this layer is therefore not pretrained.
The weights for all pretrained models that are provided by autrainer can be automatically downloaded using the autrainer fetch CLI command or the fetch() CLI wrapper function.
To optionally use model, optimizer, or scheduler checkpoints, the following attributes can be set in any model configuration file (see the example below):

- model_checkpoint: The path to the model checkpoint file. Defaults to None.
- optimizer_checkpoint: The path to the optimizer checkpoint file. Defaults to None.
- scheduler_checkpoint: The path to the scheduler checkpoint file. Defaults to None.
- skip_last_layer: Whether to skip loading the state of the last linear or convolutional layer. When set to True, the state of the last layer (if present) is omitted from both the model and optimizer, allowing for training on a different target dataset with varying output dimensions. Defaults to True.
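As a minimal sketch, a configuration that warm-starts a model from an earlier run might look as follows; the id and checkpoint paths are hypothetical placeholders:

id: ResNet-50-Warm # hypothetical id
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet50
transfer: false
model_checkpoint: /path/to/previous/run/model.pt # hypothetical path
optimizer_checkpoint: /path/to/previous/run/optimizer.pt # hypothetical path
skip_last_layer: true # omit the final layer state to allow a different output dimension

transform:
  type: image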
Note
Loading a checkpoint assumes that the model architecture is the same as the one used to create the checkpoint, and that the last layer of the model is specified as the final Linear or _ConvNd module. If the last layer is not the final module in the module order, it may not be correctly identified for skipping.
Abstract Model#
All models inherit from the AbstractModel class and implement the forward() and embeddings() methods.
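Like the built-in models, a custom AbstractModel subclass is referenced from a model configuration via its import path in _target_. A minimal sketch, assuming a hypothetical mymodule.MyModel that implements forward() and embeddings():

id: MyModel # hypothetical custom model
_target_: mymodule.MyModel # import path of the AbstractModel subclass
hidden_size: 64 # any constructor argument of the custom model

transform:
  type: raw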
Model Wrappers#
For convenience, we provide wrappers for torchvision and timm models.
- class autrainer.models.TorchvisionModel(output_dim, torchvision_name, transfer=False, **kwargs)[source]#
Wrapper for torchvision models.
- Parameters:
  - output_dim (int) – Number of output classes.
  - torchvision_name (str) – Name of the model available in torchvision.models.
  - transfer – Whether to load the model with pretrained weights. The “DEFAULT” weights are used if transfer is True. The final layer is replaced with a new layer with output_dim output features. Defaults to False.
  - kwargs – Additional arguments to pass to the model constructor.
Default Configurations
autrainer provides default configurations for all torchvision classification models. For more information on the available torchvision models as well as their parameters, refer to the torchvision classification models documentation. By default, models using transfer learning (indicated by a trailing -T in the model name) use the default pretrained weights provided by torchvision. Any additional kwargs (such as aux_logits for GoogLeNet and InceptionV3 below) are passed directly to the model constructor.

AlexNet
id: AlexNet-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: alexnet
transfer: true

transform:
  type: image

id: AlexNet
_target_: autrainer.models.TorchvisionModel
torchvision_name: alexnet
transfer: false

transform:
  type: image
ConvNeXt
id: ConvNeXt-Base-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_base
transfer: true

transform:
  type: image

id: ConvNeXt-Base
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_base
transfer: false

transform:
  type: image

id: ConvNeXt-Large-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_large
transfer: true

transform:
  type: image

id: ConvNeXt-Large
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_large
transfer: false

transform:
  type: image

id: ConvNeXt-Small-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_small
transfer: true

transform:
  type: image

id: ConvNeXt-Small
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_small
transfer: false

transform:
  type: image

id: ConvNeXt-Tiny-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_tiny
transfer: true

transform:
  type: image

id: ConvNeXt-Tiny
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_tiny
transfer: false

transform:
  type: image
Densenet
id: Densenet-121-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: densenet121
transfer: true
transform:
  type: image

id: Densenet-121
_target_: autrainer.models.TorchvisionModel
torchvision_name: densenet121
transfer: false
transform:
  type: image
EfficientNet
id: EfficientNet-B0-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b0
transfer: true
transform:
  type: image

id: EfficientNet-B0
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b0
transfer: false
transform:
  type: image

id: EfficientNet-B1-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b1
transfer: true
transform:
  type: image

id: EfficientNet-B1
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b1
transfer: false
transform:
  type: image

id: EfficientNet-B2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b2
transfer: true
transform:
  type: image

id: EfficientNet-B2
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b2
transfer: false
transform:
  type: image

id: EfficientNet-B3-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b3
transfer: true
transform:
  type: image

id: EfficientNet-B3
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b3
transfer: false
transform:
  type: image

id: EfficientNet-B4-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b4
transfer: true
transform:
  type: image

id: EfficientNet-B4
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b4
transfer: false
transform:
  type: image

id: EfficientNet-B5-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b5
transfer: true
transform:
  type: image

id: EfficientNet-B5
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b5
transfer: false
transform:
  type: image

id: EfficientNet-B6-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b6
transfer: true
transform:
  type: image

id: EfficientNet-B6
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b6
transfer: false
transform:
  type: image

id: EfficientNet-B7-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b7
transfer: true
transform:
  type: image

id: EfficientNet-B7
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b7
transfer: false
transform:
  type: image

id: EfficientNetV2-L-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_l
transfer: true
transform:
  type: image

id: EfficientNetV2-L
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_l
transfer: false
transform:
  type: image

id: EfficientNetV2-M-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_m
transfer: true
transform:
  type: image

id: EfficientNetV2-M
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_m
transfer: false
transform:
  type: image

id: EfficientNetV2-S-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_s
transfer: true
transform:
  type: image

id: EfficientNetV2-S
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_s
transfer: false
transform:
  type: image
GoogLeNet
id: GoogLeNet-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: googlenet
transfer: true
aux_logits: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

id: GoogLeNet
_target_: autrainer.models.TorchvisionModel
torchvision_name: googlenet
transfer: false
aux_logits: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224
InceptionV3
id: InceptionV3-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: inception_v3
transfer: true
aux_logits: false

transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 299
        width: 299

id: InceptionV3
_target_: autrainer.models.TorchvisionModel
torchvision_name: inception_v3
transfer: false
aux_logits: false

transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 299
        width: 299
MaxViT
id: MaxViT-T-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: maxvit_t
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

id: MaxViT-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: maxvit_t
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224
MnasNet
id: MnasNet-0.5-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet0_5
transfer: true
transform:
  type: image

id: MnasNet-0.5
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet0_5
transfer: false
transform:
  type: image

id: MnasNet-0.75-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet0_75
transfer: true
transform:
  type: image

id: MnasNet-0.75
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet0_75
transfer: false
transform:
  type: image

id: MnasNet-1.0-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet1_0
transfer: true
transform:
  type: image

id: MnasNet-1.0
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet1_0
transfer: false
transform:
  type: image

id: MnasNet-1.3-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet1_3
transfer: true
transform:
  type: image

id: MnasNet-1.3
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet1_3
transfer: false
transform:
  type: image
MobileNet
id: MobileNetV2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v2
transfer: true
transform:
  type: image

id: MobileNetV2
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v2
transfer: false
transform:
  type: image

id: MobileNetV3-Large-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v3_large
transfer: true
transform:
  type: image

id: MobileNetV3-Large
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v3_large
transfer: false
transform:
  type: image

id: MobileNetV3-Small-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v3_small
transfer: true
transform:
  type: image

id: MobileNetV3-Small
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v3_small
transfer: false
transform:
  type: image
RegNet
id: RegNetX-1.6GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_1_6gf
transfer: true
transform:
  type: image

id: RegNetX-1.6GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_1_6gf
transfer: false
transform:
  type: image

id: RegNetX-16GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_16gf
transfer: true
transform:
  type: image

id: RegNetX-16GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_16gf
transfer: false
transform:
  type: image

id: RegNetX-3.2GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_3_2gf
transfer: true
transform:
  type: image

id: RegNetX-3.2GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_3_2gf
transfer: false
transform:
  type: image

id: RegNetX-32GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_32gf
transfer: true
transform:
  type: image

id: RegNetX-32GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_32gf
transfer: false
transform:
  type: image

id: RegNetX-400MF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_400mf
transfer: true
transform:
  type: image

id: RegNetX-400MF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_400mf
transfer: false
transform:
  type: image

id: RegNetX-800MF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_800mf
transfer: true
transform:
  type: image

id: RegNetX-800MF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_800mf
transfer: false
transform:
  type: image

id: RegNetX-8GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_8gf
transfer: true
transform:
  type: image

id: RegNetX-8GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_8gf
transfer: false
transform:
  type: image
id: RegNetY-1.6GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_1_6gf
transfer: true
transform:
  type: image

id: RegNetY-1.6GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_1_6gf
transfer: false
transform:
  type: image

id: RegNetY-128GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_128gf
transfer: true
transform:
  type: image

id: RegNetY-128GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_128gf
transfer: false
transform:
  type: image

id: RegNetY-16GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_16gf
transfer: true
transform:
  type: image

id: RegNetY-16GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_16gf
transfer: false
transform:
  type: image

id: RegNetY-3.2GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_3_2gf
transfer: true
transform:
  type: image

id: RegNetY-3.2GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_3_2gf
transfer: false
transform:
  type: image

id: RegNetY-32GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_32gf
transfer: true
transform:
  type: image

id: RegNetY-32GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_32gf
transfer: false
transform:
  type: image

id: RegNetY-400MF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_400mf
transfer: true
transform:
  type: image

id: RegNetY-400MF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_400mf
transfer: false
transform:
  type: image

id: RegNetY-800MF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_800mf
transfer: true
transform:
  type: image

id: RegNetY-800MF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_800mf
transfer: false
transform:
  type: image

id: RegNetY-8GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_8gf
transfer: true
transform:
  type: image

id: RegNetY-8GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_8gf
transfer: false
transform:
  type: image
ResNet
id: ResNet-101-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet101
transfer: true
transform:
  type: image

id: ResNet-101
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet101
transfer: false
transform:
  type: image

id: ResNet-152-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet152
transfer: true
transform:
  type: image

id: ResNet-152
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet152
transfer: false
transform:
  type: image

id: ResNet-18-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet18
transfer: true
transform:
  type: image

id: ResNet-18
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet18
transfer: false
transform:
  type: image

id: ResNet-34-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet34
transfer: true
transform:
  type: image

id: ResNet-34
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet34
transfer: false
transform:
  type: image

id: ResNet-50-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet50
transfer: true
transform:
  type: image

id: ResNet-50
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet50
transfer: false
transform:
  type: image
ResNeXt
id: ResNeXt-101-32x8d-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext101_32x8d
transfer: true
transform:
  type: image

id: ResNeXt-101-32x8d
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext101_32x8d
transfer: false
transform:
  type: image

id: ResNeXt-101-64x4d-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext101_64x4d
transfer: true
transform:
  type: image

id: ResNeXt-101-64x4d
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext101_64x4d
transfer: false
transform:
  type: image

id: ResNeXt-50-32x4d-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext50_32x4d
transfer: true
transform:
  type: image

id: ResNeXt-50-32x4d
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext50_32x4d
transfer: false
transform:
  type: image
ShuffleNet
id: ShuffleNetV2-0.5x-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x0_5
transfer: true
transform:
  type: image

id: ShuffleNetV2-0.5x
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x0_5
transfer: false
transform:
  type: image

id: ShuffleNetV2-1.0x-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x1_0
transfer: true
transform:
  type: image

id: ShuffleNetV2-1.0x
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x1_0
transfer: false
transform:
  type: image

id: ShuffleNetV2-1.5x-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x1_5
transfer: true
transform:
  type: image

id: ShuffleNetV2-1.5x
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x1_5
transfer: false
transform:
  type: image

id: ShuffleNetV2-2.0x-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x2_0
transfer: true
transform:
  type: image

id: ShuffleNetV2-2.0x
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x2_0
transfer: false
transform:
  type: image
SqueezeNet
id: SqueezeNet-1.0-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: squeezenet1_0
transfer: true
transform:
  type: image

id: SqueezeNet-1.0
_target_: autrainer.models.TorchvisionModel
torchvision_name: squeezenet1_0
transfer: false
transform:
  type: image

id: SqueezeNet-1.1-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: squeezenet1_1
transfer: true
transform:
  type: image

id: SqueezeNet-1.1
_target_: autrainer.models.TorchvisionModel
torchvision_name: squeezenet1_1
transfer: false
transform:
  type: image
Swin
id: Swin-B-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_b
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

id: Swin-B
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_b
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

id: Swin-S-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_s
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

id: Swin-S
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_s
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

id: Swin-T-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_t
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

id: Swin-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_t
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

id: Swin-V2-B-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_b
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

id: Swin-V2-B
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_b
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

id: Swin-V2-S-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_s
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

id: Swin-V2-S
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_s
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

id: Swin-V2-T-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_t
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

id: Swin-V2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_t
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256
VGG
id: VGG-11-BN-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg11_bn
transfer: true
transform:
  type: image

id: VGG-11-BN
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg11_bn
transfer: false
transform:
  type: image

id: VGG-11-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg11
transfer: true
transform:
  type: image

id: VGG-11
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg11
transfer: false
transform:
  type: image

id: VGG-13-BN-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg13_bn
transfer: true
transform:
  type: image

id: VGG-13-BN
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg13_bn
transfer: false
transform:
  type: image

id: VGG-13-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg13
transfer: true
transform:
  type: image

id: VGG-13
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg13
transfer: false
transform:
  type: image

id: VGG-16-BN-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg16_bn
transfer: true
transform:
  type: image

id: VGG-16-BN
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg16_bn
transfer: false
transform:
  type: image

id: VGG-16-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg16
transfer: true
transform:
  type: image

id: VGG-16
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg16
transfer: false
transform:
  type: image

id: VGG-19-BN-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg19_bn
transfer: true
transform:
  type: image

id: VGG-19-BN
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg19_bn
transfer: false
transform:
  type: image

id: VGG-19-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg19
transfer: true
transform:
  type: image

id: VGG-19
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg19
transfer: false
transform:
  type: image
Wide-ResNet
id: Wide-ResNet-101-2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: wide_resnet101_2
transfer: true
transform:
  type: image

id: Wide-ResNet-101-2
_target_: autrainer.models.TorchvisionModel
torchvision_name: wide_resnet101_2
transfer: false
transform:
  type: image

id: Wide-ResNet-50-2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: wide_resnet50_2
transfer: true
transform:
  type: image

id: Wide-ResNet-50-2
_target_: autrainer.models.TorchvisionModel
torchvision_name: wide_resnet50_2
transfer: false
transform:
  type: image
- class autrainer.models.TimmModel(output_dim, timm_name, transfer=False, **kwargs)[source]#
Wrapper for timm models.
- Parameters:
  - output_dim (int) – Number of output classes.
  - timm_name (str) – Name of the model available in timm.create_model.
  - transfer – Whether to load the model with pretrained weights. The final layer is replaced with a new layer with output_dim output features. Defaults to False.
  - kwargs – Additional arguments to pass to the model constructor.
Default Configurations
autrainer does not provide default configurations for timm models. To discover the available timm models for creating a custom configuration, refer to the timm documentation. An example configuration is sketched below.
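As a sketch, a custom configuration wrapping a timm model might look as follows; the id is illustrative, and timm_name can be any model name accepted by timm.create_model:

id: MobileNetV3-Large-100-Timm-T # illustrative id
_target_: autrainer.models.TimmModel
timm_name: mobilenetv3_large_100
transfer: true

transform:
  type: image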
Audio Models#
autrainer provides a number of different audio-specific models.
- class autrainer.models.ASTModel(output_dim, num_hidden_layers=12, hidden_size=128, dropout=0.5, transfer=None)[source]#
Audio Spectrogram Transformer (AST) model. For more information see: https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/audio-spectrogram-transformer
- Parameters:
  - output_dim (int) – Output dimension of the model.
  - num_hidden_layers (int) – Number of hidden layers in the transformer. Defaults to 12.
  - hidden_size (int) – Hidden size of the linear layer. Defaults to 128.
  - dropout (float) – Dropout rate. Defaults to 0.5.
  - transfer (Optional[str]) – Name of the pretrained model to load. If None, the default AST fine-tuned on AudioSet is used. Defaults to None. For more information see: https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593
Default Configurations
id: ASTModel-T
_target_: autrainer.models.ASTModel
transfer: MIT/ast-finetuned-audioset-10-10-0.4593

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: AST
        fe_transfer: MIT/ast-finetuned-audioset-10-10-0.4593
- class autrainer.models.AudioRNNModel(output_dim, model_name, hidden_size=256, num_layers=2, dropout=0.5, cell='LSTM', bidirectional=False)[source]#
Audio RNN model.
- Parameters:
  - output_dim (int) – Output dimension of the model.
  - model_name (str) – Model name in [“emo18”, “zhao19”].
  - hidden_size (int) – Hidden size of the RNN. Defaults to 256.
  - num_layers (int) – Number of layers of the RNN. Defaults to 2.
  - dropout (float) – Dropout rate. Defaults to 0.5.
  - cell (str) – Type of RNN cell in [“LSTM”, “GRU”]. Defaults to “LSTM”.
  - bidirectional (bool) – Whether to use a bidirectional RNN. Defaults to False.
Default Configurations
id: End2You-emo18
_target_: autrainer.models.AudioRNNModel
model_name: emo18

transform:
  type: raw

id: End2You-zhao19
_target_: autrainer.models.AudioRNNModel
model_name: zhao19

transform:
  type: raw
- class autrainer.models.Cnn10(output_dim, segmentwise=False, in_channels=1, transfer=None)[source]#
CNN10 PANN model. For more information see: https://doi.org/10.48550/arXiv.1912.10211
- Parameters:
  - output_dim (int) – Output dimension of the model.
  - segmentwise (bool) – Whether to use the segmentwise path or the clipwise path. Defaults to False.
  - in_channels (int) – Number of input channels. Defaults to 1.
  - transfer (Optional[str]) – Link to the weights to transfer. If None, the weights will be randomly initialized. Defaults to None.
Default Configurations
id: Cnn10-32k-T-online
_target_: autrainer.models.Cnn10
transfer: https://zenodo.org/records/3987831/files/Cnn10_mAP%3D0.380.pth

transform:
  type: raw
  base:
    - autrainer.transforms.Resample:
        current_sr: 48000
        target_sr: 32000
    - autrainer.transforms.PannMel:
        sample_rate: 32000
        window_size: 1024
        hop_size: 320
        mel_bins: 64
        fmin: 50
        fmax: 14000
        ref: 1.0
        amin: 1e-10
        top_db: null

id: Cnn10-32k-T
_target_: autrainer.models.Cnn10
transfer: https://zenodo.org/records/3987831/files/Cnn10_mAP%3D0.380.pth

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null

id: Cnn10
_target_: autrainer.models.Cnn10

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null
- class autrainer.models.Cnn14(output_dim, segmentwise=False, in_channels=1, transfer=None)[source]#
CNN14 PANN model. For more information see: https://doi.org/10.48550/arXiv.1912.10211
- Parameters:
  - output_dim (int) – Output dimension of the model.
  - segmentwise (bool) – Whether to use the segmentwise path or the clipwise path. Defaults to False.
  - in_channels (int) – Number of input channels. Defaults to 1.
  - transfer (Optional[str]) – Link to the weights to transfer. If None, the weights will be randomly initialized. Defaults to None.
Default Configurations
id: Cnn14-16k-T
_target_: autrainer.models.Cnn14
transfer: https://zenodo.org/records/3987831/files/Cnn14_16k_mAP%3D0.438.pth

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null

id: Cnn14-32k-T
_target_: autrainer.models.Cnn14
transfer: https://zenodo.org/records/3987831/files/Cnn14_mAP%3D0.431.pth

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null

id: Cnn14
_target_: autrainer.models.Cnn14

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null
- class autrainer.models.FFNN(output_dim, input_size, hidden_size, num_layers=2, dropout=0.5)[source]#
Feedforward neural network.
- Parameters:
  - output_dim (int) – Output dimension.
  - input_size (int) – Input size.
  - hidden_size (int) – Hidden size.
  - num_layers (int) – Number of layers. Defaults to 2.
  - dropout (float) – Dropout rate. Defaults to 0.5.
Default Configurations
id: ToyFFNN
_target_: autrainer.models.FFNN
input_size: 64
hidden_size: 64
num_layers: 2

transform:
  type: tabular
- class autrainer.models.LEAFNet(output_dim, leaf_filters=40, kernel_size=25, stride=0.0625, window_stride=10, padding_kernel_size=25, sample_rate=16000, min_freq=60, max_freq=7800, efficientnet_type='efficientnet_b0', mode='interspeech', initialization='mel', generator_seed=42, transfer=False)[source]#
EfficientNet with LEAF-Is frontend. Used to reproduce work from: https://www.isca-archive.org/interspeech_2023/meng23c_interspeech.html
Also see the original LEAF and PCEN papers (cf. speechbrain).
We take and slightly adapt the LEAF frontend implementation from: Hanyu-Meng/Adapting-LEAF
- Parameters:
  - output_dim (int) – Output dimension.
  - leaf_filters (int) – Number of LEAF filterbanks to train. Defaults to 40.
  - kernel_size (int) – Size of kernels applied by LEAF (in ms). Defaults to 25.
  - stride (float) – Stride of LEAF (in ms). Defaults to 0.0625.
  - window_stride (int) – Stride of the lowpass filter (in ms). Defaults to 10.
  - padding_kernel_size (int) – Size of the lowpass filter (in ms). Defaults to 25.
  - sample_rate (int) – Used to compute LEAF parameters. Defaults to 16000.
  - min_freq (int) – Minimum frequency analyzed by LEAF. Defaults to 60.
  - max_freq (int) – Maximum frequency analyzed by LEAF. Defaults to 7800.
  - efficientnet_type (str) – EfficientNet type to use from timm. Defaults to “efficientnet_b0”.
  - mode (str) – Implementation according to the “interspeech” paper (Meng et al.) or “speech_brain”. Defaults to “interspeech”.
  - initialization (str) – Filterbank initialization in [“mel”, “bark”, “linear-constant”, “constant”, “uniform”, “zeros”]. Defaults to “mel”.
  - generator_seed (int) – Seed for the random generator. Defaults to 42.
  - transfer (bool) – Whether to use EfficientNet weights from ImageNet. Defaults to False.
- Raises:
  - ValueError – If efficientnet_type is not supported.
  - ValueError – If mode is not supported.
Default Configurations
id: LEAFNet
_target_: autrainer.models.LEAFNet

transform:
  type: raw
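The default configuration above leaves all constructor parameters at their defaults; any of them can be overridden in a custom configuration. As an illustrative sketch (the id is hypothetical), using Bark filterbank initialization and pretrained EfficientNet weights:

id: LEAFNet-Bark-T # hypothetical id
_target_: autrainer.models.LEAFNet
initialization: bark
transfer: true # use EfficientNet weights from ImageNet

transform:
  type: raw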
- class autrainer.models.SeqFFNN(output_dim, backbone_input_dim, backbone_hidden_size, backbone_num_layers, hidden_size, num_layers=2, dropout=0.5, backbone_dropout=0.5, backbone_cell='LSTM', backbone_time_pooling=True, backbone_bidirectional=False)[source]#
Sequential model with FFNN frontend.
- Parameters:
  - output_dim (int) – Output dimension of the FFNN.
  - backbone_input_dim (int) – Input dimension of the backbone.
  - backbone_hidden_size (int) – Hidden size of the backbone.
  - backbone_num_layers (int) – Number of layers of the backbone.
  - hidden_size (int) – Hidden size of the FFNN.
  - num_layers (int) – Number of layers of the FFNN. Defaults to 2.
  - dropout (float) – Dropout rate of the FFNN. Defaults to 0.5.
  - backbone_dropout (float) – Dropout rate of the backbone. Defaults to 0.5.
  - backbone_cell (str) – Cell type of the backbone in [“LSTM”, “GRU”]. Defaults to “LSTM”.
  - backbone_time_pooling (bool) – Whether to apply time pooling in the backbone. Defaults to True.
  - backbone_bidirectional (bool) – Whether to use a bidirectional backbone. Defaults to False.
Default Configurations
id: Seq-FFNN-eGeMAPS
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 25
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

id: Seq-FFNN-IS09
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 32
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

id: Seq-FFNN-IS10
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 76
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

id: Seq-FFNN-IS11
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 118
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

id: Seq-FFNN-IS12
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 128
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

id: Seq-FFNN-IS13
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 130
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

id: Seq-FFNN-IS16
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 130
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular
- class autrainer.models.TDNNFFNN(output_dim, hidden_size, num_layers=2, dropout=0.5)[source]#
Time Delay Neural Network with FFNN frontend.
- Parameters:
  - output_dim (int) – Output dimension.
  - hidden_size (int) – Hidden size.
  - num_layers (int) – Number of layers. Defaults to 2.
  - dropout (float) – Dropout rate. Defaults to 0.5.
Default Configurations
id: TDNNFFNN-T
_target_: autrainer.models.TDNNFFNN
hidden_size: 32

transform:
  type: raw
- class autrainer.models.W2V2FFNN(output_dim, model_name, freeze_extractor, hidden_size, num_layers=2, dropout=0.5)[source]#
Wav2Vec2 model with FFNN frontend adapted for audio classification. For more information, see: https://huggingface.co/docs/transformers/model_doc/wav2vec2
- Parameters:
  - output_dim (int) – Output dimension of the FFNN.
  - model_name (str) – Name of the model loaded from Huggingface.
  - freeze_extractor (bool) – Whether to freeze the feature extractor.
  - hidden_size (int) – Hidden size of the FFNN.
  - num_layers (int) – Number of layers of the FFNN. Defaults to 2.
  - dropout (float) – Dropout rate. Defaults to 0.5.
Default Configurations
id: w2v2-b
_target_: autrainer.models.W2V2FFNN
model_name: facebook/wav2vec2-base
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: facebook/wav2vec2-base

id: w2v2-l-100k
_target_: autrainer.models.W2V2FFNN
model_name: facebook/wav2vec2-large-100k-voxpopuli
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: facebook/wav2vec2-large-100k-voxpopuli

id: w2v2-l-emo
_target_: autrainer.models.W2V2FFNN
model_name: audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim

id: w2v2-l-rob
_target_: autrainer.models.W2V2FFNN
model_name: facebook/wav2vec2-large-robust
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: facebook/wav2vec2-large-robust

id: w2v2-l
_target_: autrainer.models.W2V2FFNN
model_name: facebook/wav2vec2-large
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: facebook/wav2vec2-large
- class autrainer.models.WhisperFFNN(output_dim, model_name, hidden_size, num_layers=2, dropout=0.5)[source]#
Whisper model with FFNN frontend adapted for audio classification. For more information, see: https://doi.org/10.48550/arXiv.2212.04356
- Parameters:
  - output_dim (int) – Output dimension of the FFNN.
  - model_name (str) – Name of the model loaded from Huggingface.
  - hidden_size (int) – Hidden size of the FFNN.
  - num_layers (int) – Number of layers of the FFNN. Defaults to 2.
  - dropout (float) – Dropout rate. Defaults to 0.5.
Default Configurations
id: Whisper-FFNN-Base-T
_target_: autrainer.models.WhisperFFNN
model_name: openai/whisper-base
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: Whisper
        fe_transfer: openai/whisper-base

id: Whisper-FFNN-Small-T
_target_: autrainer.models.WhisperFFNN
model_name: openai/whisper-small
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: Whisper
        fe_transfer: openai/whisper-small

id: Whisper-FFNN-Tiny-T
_target_: autrainer.models.WhisperFFNN
model_name: openai/whisper-tiny
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: Whisper
        fe_transfer: openai/whisper-tiny