Models

autrainer provides a number of different audio-specific models as well as wrappers for torchvision and timm models.

Tip

To create custom models, refer to the custom models tutorial.

Default configurations whose names end in -T use transfer learning with pretrained weights. To avoid race conditions when using Launcher Plugins that may run multiple training jobs in parallel, the pretrained weights are downloaded before training with autrainer fetch or fetch().
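The -T suffix is a naming convention only. Some non-transfer configurations listed below also end in -T because T abbreviates "tiny" (for example MaxViT-T, Swin-T, and Swin-V2-T), so a script that needs to know whether a configuration uses pretrained weights is safer reading its transfer field. A minimal sketch, assuming the configurations have already been parsed into plain dictionaries; the helper name uses_transfer is hypothetical and not part of autrainer:

```python
from typing import Any, Mapping


def uses_transfer(config: Mapping[str, Any]) -> bool:
    """Return True if a parsed model config requests pretrained weights.

    Hypothetical helper: it inspects the ``transfer`` field instead of the id,
    because ids such as "MaxViT-T" end in "-T" for "tiny", not for transfer.
    ``transfer`` may be a bool (torchvision wrapper) or a checkpoint name
    string (ASTModel); a missing field counts as no transfer.
    """
    return bool(config.get("transfer", False))


# Values copied from the default configurations documented on this page.
configs = [
    {"id": "MaxViT-T", "transfer": False},
    {"id": "MaxViT-T-T", "transfer": True},
    {"id": "ASTModel-T", "transfer": "MIT/ast-finetuned-audioset-10-10-0.4593"},
    {"id": "End2You-emo18"},  # no transfer field at all
]

for cfg in configs:
    naive = cfg["id"].endswith("-T")
    print(f"{cfg['id']:<14} suffix says {naive!s:<5} field says {uses_transfer(cfg)}")
```

For MaxViT-T the suffix check reports True while the field check reports False.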

Note

Most of the provided pretrained weights were obtained on the ImageNet or AudioSet datasets. To ensure compatibility with any number of output dimensions, the last linear layer of the model is replaced with a new linear layer with the required number of output dimensions; this layer is therefore not pretrained.

The weights for all pretrained models that are provided by autrainer can be automatically downloaded using the autrainer fetch CLI command or the fetch() CLI wrapper function.

To optionally use model, optimizer, or scheduler checkpoints, the following attributes can be set in any model configuration file:

  • model_checkpoint: The path to the model checkpoint file. Defaults to None.

  • optimizer_checkpoint: The path to the optimizer checkpoint file. Defaults to None.

  • scheduler_checkpoint: The path to the scheduler checkpoint file. Defaults to None.

  • skip_last_layer: Whether to skip loading the state of the last linear or convolutional layer. When set to True, the state of the last layer (if present) is omitted from both the model and optimizer, allowing for training on a different target dataset with varying output dimensions. Defaults to True.

Note

Loading a checkpoint assumes that the model architecture matches the one used to create the checkpoint and that the model's output layer is its final Linear or _ConvNd module. If the output layer is not the last such module in the module order, it may not be identified correctly for skipping.
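The note above can be illustrated without PyTorch. The sketch below is not autrainer's implementation; it assumes that the model's modules are available in registration order as (name, type) pairs, similar to what torch.nn.Module.named_modules() yields, and that a checkpoint is a flat mapping from parameter names to values. The helper drop_last_layer is hypothetical and only covers the model state, not the optimizer state:

```python
from typing import Any, Dict, Iterable, Tuple

# Module types treated as candidate output layers (stand-ins for torch.nn.Linear
# and the convolution classes derived from torch.nn.modules.conv._ConvNd).
LAST_LAYER_TYPES = {"Linear", "Conv1d", "Conv2d", "Conv3d"}


def drop_last_layer(
    modules: Iterable[Tuple[str, str]], state: Dict[str, Any]
) -> Dict[str, Any]:
    """Return a copy of ``state`` without the parameters of the last layer.

    The "last layer" is the final Linear/Conv module in registration order,
    so an output head registered before another such module is not found.
    """
    last = None
    for name, module_type in modules:
        if module_type in LAST_LAYER_TYPES:
            last = name
    if last is None:
        return dict(state)
    prefix = last + "."  # the dot avoids also matching e.g. "fc2.weight"
    return {k: v for k, v in state.items() if not k.startswith(prefix)}


# Toy ResNet-like layout: the classifier "fc" is registered last.
modules = [("conv1", "Conv2d"), ("bn1", "BatchNorm2d"), ("fc", "Linear")]
state = {
    "conv1.weight": "w0",
    "bn1.weight": "w1",
    "bn1.bias": "b1",
    "fc.weight": "w2",
    "fc.bias": "b2",
}
print(sorted(drop_last_layer(modules, state)))
```

If a model registered its head first (for example a "head" Linear followed by a "proj" Conv1d), this sketch would drop "proj" instead, which is the failure mode the note describes.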

Abstract Model

All models inherit from the AbstractModel class and implement the forward() and embeddings() methods.

class autrainer.models.AbstractModel(output_dim)

Abstract model class.

Parameters:

output_dim (int) – Output dimension of the model.

abstract forward(x)

Forward pass of the model.

Parameters:

x (Tensor) – Input tensor.

Return type:

Tensor

Returns:

Output tensor.

abstract embeddings(x)

Get embeddings from the model.

Parameters:

x (Tensor) – Input tensor.

Return type:

Tensor

Returns:

Embeddings.
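To make the contract concrete without requiring PyTorch, the sketch below mirrors it with Python's abc module, using plain lists in place of tensors. SketchAbstractModel is a hypothetical stand-in, not autrainer.models.AbstractModel itself, and having forward() apply an output mapping to the result of embeddings() is a design choice assumed for illustration rather than something this reference prescribes:

```python
from abc import ABC, abstractmethod
from typing import List

Vector = List[float]


class SketchAbstractModel(ABC):
    """Hypothetical stand-in mirroring the documented AbstractModel interface."""

    def __init__(self, output_dim: int) -> None:
        self.output_dim = output_dim

    @abstractmethod
    def forward(self, x: Vector) -> Vector:
        """Forward pass of the model."""

    @abstractmethod
    def embeddings(self, x: Vector) -> Vector:
        """Get embeddings from the model."""


class MeanPoolModel(SketchAbstractModel):
    """Toy model: the embedding is the input mean, the head scales it per output."""

    def embeddings(self, x: Vector) -> Vector:
        return [sum(x) / len(x)]

    def forward(self, x: Vector) -> Vector:
        # Assumed design: reuse embeddings() and map them to output_dim values.
        (e,) = self.embeddings(x)
        return [e * (i + 1) for i in range(self.output_dim)]


model = MeanPoolModel(output_dim=3)
print(model.embeddings([1.0, 2.0, 3.0]))  # [2.0]
print(model.forward([1.0, 2.0, 3.0]))  # [2.0, 4.0, 6.0]
```

In this stand-in, a subclass that leaves either method unimplemented cannot be instantiated; abc raises a TypeError.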

Model Wrappers

For convenience, we provide wrappers for torchvision and timm models.

class autrainer.models.TorchvisionModel(output_dim, torchvision_name, transfer=False, **kwargs)

Wrapper for torchvision models.

Parameters:
  • output_dim (int) – Number of output classes.

  • torchvision_name (str) – Name of the model available in torchvision.models.

  • transfer – Whether to load the model with pretrained weights. The “DEFAULT” weights are used if transfer is True. The final layer is replaced with a new layer with output_dim output features. Defaults to False.

  • kwargs – Additional arguments to pass to the model constructor.
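Conceptually, transfer selects which torchvision weights are requested. A minimal sketch, assuming torchvision 0.14 or later, where torchvision.models.get_model() builds a model by name and the model builders accept weights="DEFAULT" or weights=None; the helper names are hypothetical, and replacing the final layer with one that has output_dim output features is omitted:

```python
from typing import Any, Dict


def torchvision_builder_kwargs(transfer: bool = False, **kwargs: Any) -> Dict[str, Any]:
    """Map a ``transfer`` flag onto torchvision's ``weights`` argument.

    Hypothetical helper: "DEFAULT" selects the default pretrained weights,
    None builds a randomly initialised model. Extra kwargs (for example
    ``aux_logits`` from the GoogLeNet configs) are forwarded unchanged.
    """
    return {"weights": "DEFAULT" if transfer else None, **kwargs}


def build_torchvision_model(torchvision_name: str, transfer: bool = False, **kwargs: Any):
    # Imported lazily so the mapping above can be used without torchvision installed.
    from torchvision.models import get_model

    return get_model(torchvision_name, **torchvision_builder_kwargs(transfer, **kwargs))


print(torchvision_builder_kwargs(transfer=True, aux_logits=False))
```

Calling build_torchvision_model("googlenet", transfer=True, aux_logits=False) would then roughly correspond to the GoogLeNet-T configuration below, minus the output-layer replacement.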

Default Configurations

autrainer provides default configurations for all torchvision classification models. For more information on the available torchvision models as well as their parameters, refer to the torchvision classification models documentation.

By default, models using transfer learning (indicated by a trailing -T in the model name) use the default pretrained weights provided by torchvision.

AlexNet

conf/model/AlexNet-T.yaml
id: AlexNet-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: alexnet
transfer: true

transform:
  type: image

conf/model/AlexNet.yaml
id: AlexNet
_target_: autrainer.models.TorchvisionModel
torchvision_name: alexnet
transfer: false

transform:
  type: image

ConvNeXt

conf/model/ConvNeXt-Base-T.yaml
id: ConvNeXt-Base-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_base
transfer: true

transform:
  type: image

conf/model/ConvNeXt-Base.yaml
id: ConvNeXt-Base
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_base
transfer: false

transform:
  type: image

conf/model/ConvNeXt-Large-T.yaml
id: ConvNeXt-Large-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_large
transfer: true

transform:
  type: image

conf/model/ConvNeXt-Large.yaml
id: ConvNeXt-Large
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_large
transfer: false

transform:
  type: image

conf/model/ConvNeXt-Small-T.yaml
id: ConvNeXt-Small-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_small
transfer: true

transform:
  type: image

conf/model/ConvNeXt-Small.yaml
id: ConvNeXt-Small
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_small
transfer: false

transform:
  type: image

conf/model/ConvNeXt-Tiny-T.yaml
id: ConvNeXt-Tiny-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_tiny
transfer: true

transform:
  type: image

conf/model/ConvNeXt-Tiny.yaml
id: ConvNeXt-Tiny
_target_: autrainer.models.TorchvisionModel
torchvision_name: convnext_tiny
transfer: false

transform:
  type: image

Densenet

conf/model/Densenet-121-T.yaml
id: Densenet-121-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: densenet121
transfer: true
transform:
  type: image

conf/model/Densenet-121.yaml
id: Densenet-121
_target_: autrainer.models.TorchvisionModel
torchvision_name: densenet121
transfer: false
transform:
  type: image

EfficientNet

conf/model/EfficientNet-B0-T.yaml
id: EfficientNet-B0-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b0
transfer: true
transform:
  type: image

conf/model/EfficientNet-B0.yaml
id: EfficientNet-B0
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b0
transfer: false
transform:
  type: image

conf/model/EfficientNet-B1-T.yaml
id: EfficientNet-B1-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b1
transfer: true
transform:
  type: image

conf/model/EfficientNet-B1.yaml
id: EfficientNet-B1
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b1
transfer: false
transform:
  type: image

conf/model/EfficientNet-B2-T.yaml
id: EfficientNet-B2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b2
transfer: true
transform:
  type: image

conf/model/EfficientNet-B2.yaml
id: EfficientNet-B2
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b2
transfer: false
transform:
  type: image

conf/model/EfficientNet-B3-T.yaml
id: EfficientNet-B3-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b3
transfer: true
transform:
  type: image

conf/model/EfficientNet-B3.yaml
id: EfficientNet-B3
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b3
transfer: false
transform:
  type: image

conf/model/EfficientNet-B4-T.yaml
id: EfficientNet-B4-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b4
transfer: true
transform:
  type: image

conf/model/EfficientNet-B4.yaml
id: EfficientNet-B4
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b4
transfer: false
transform:
  type: image

conf/model/EfficientNet-B5-T.yaml
id: EfficientNet-B5-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b5
transfer: true
transform:
  type: image

conf/model/EfficientNet-B5.yaml
id: EfficientNet-B5
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b5
transfer: false
transform:
  type: image

conf/model/EfficientNet-B6-T.yaml
id: EfficientNet-B6-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b6
transfer: true
transform:
  type: image

conf/model/EfficientNet-B6.yaml
id: EfficientNet-B6
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b6
transfer: false
transform:
  type: image

conf/model/EfficientNet-B7-T.yaml
id: EfficientNet-B7-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b7
transfer: true
transform:
  type: image

conf/model/EfficientNet-B7.yaml
id: EfficientNet-B7
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_b7
transfer: false
transform:
  type: image

conf/model/EfficientNetV2-L-T.yaml
id: EfficientNetV2-L-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_l
transfer: true
transform:
  type: image

conf/model/EfficientNetV2-L.yaml
id: EfficientNetV2-L
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_l
transfer: false
transform:
  type: image

conf/model/EfficientNetV2-M-T.yaml
id: EfficientNetV2-M-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_m
transfer: true
transform:
  type: image

conf/model/EfficientNetV2-M.yaml
id: EfficientNetV2-M
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_m
transfer: false
transform:
  type: image

conf/model/EfficientNetV2-S-T.yaml
id: EfficientNetV2-S-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_s
transfer: true
transform:
  type: image

conf/model/EfficientNetV2-S.yaml
id: EfficientNetV2-S
_target_: autrainer.models.TorchvisionModel
torchvision_name: efficientnet_v2_s
transfer: false
transform:
  type: image

GoogLeNet

conf/model/GoogLeNet-T.yaml
id: GoogLeNet-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: googlenet
transfer: true
aux_logits: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

conf/model/GoogLeNet.yaml
id: GoogLeNet
_target_: autrainer.models.TorchvisionModel
torchvision_name: googlenet
transfer: false
aux_logits: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

InceptionV3

conf/model/InceptionV3-T.yaml
id: InceptionV3-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: inception_v3
transfer: true
aux_logits: false

transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 299
        width: 299

conf/model/InceptionV3.yaml
id: InceptionV3
_target_: autrainer.models.TorchvisionModel
torchvision_name: inception_v3
transfer: false
aux_logits: false

transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 299
        width: 299

MaxViT

conf/model/MaxViT-T-T.yaml
id: MaxViT-T-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: maxvit_t
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

conf/model/MaxViT-T.yaml
id: MaxViT-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: maxvit_t
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

MnasNet

conf/model/MnasNet-0.5-T.yaml
id: MnasNet-0.5-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet0_5
transfer: true
transform:
  type: image

conf/model/MnasNet-0.5.yaml
id: MnasNet-0.5
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet0_5
transfer: false
transform:
  type: image

conf/model/MnasNet-0.75-T.yaml
id: MnasNet-0.75-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet0_75
transfer: true
transform:
  type: image

conf/model/MnasNet-0.75.yaml
id: MnasNet-0.75
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet0_75
transfer: false
transform:
  type: image

conf/model/MnasNet-1.0-T.yaml
id: MnasNet-1.0-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet1_0
transfer: true
transform:
  type: image

conf/model/MnasNet-1.0.yaml
id: MnasNet-1.0
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet1_0
transfer: false
transform:
  type: image

conf/model/MnasNet-1.3-T.yaml
id: MnasNet-1.3-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet1_3
transfer: true
transform:
  type: image

conf/model/MnasNet-1.3.yaml
id: MnasNet-1.3
_target_: autrainer.models.TorchvisionModel
torchvision_name: mnasnet1_3
transfer: false
transform:
  type: image

MobileNet

conf/model/MobileNetV2-T.yaml
id: MobileNetV2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v2
transfer: true
transform:
  type: image

conf/model/MobileNetV2.yaml
id: MobileNetV2
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v2
transfer: false
transform:
  type: image

conf/model/MobileNetV3-Large-T.yaml
id: MobileNetV3-Large-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v3_large
transfer: true
transform:
  type: image

conf/model/MobileNetV3-Large.yaml
id: MobileNetV3-Large
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v3_large
transfer: false
transform:
  type: image

conf/model/MobileNetV3-Small-T.yaml
id: MobileNetV3-Small-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v3_small
transfer: true
transform:
  type: image

conf/model/MobileNetV3-Small.yaml
id: MobileNetV3-Small
_target_: autrainer.models.TorchvisionModel
torchvision_name: mobilenet_v3_small
transfer: false
transform:
  type: image

RegNet

conf/model/RegNetX-1.6GF-T.yaml
id: RegNetX-1.6GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_1_6gf
transfer: true
transform:
  type: image

conf/model/RegNetX-1.6GF.yaml
id: RegNetX-1.6GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_1_6gf
transfer: false
transform:
  type: image

conf/model/RegNetX-16GF-T.yaml
id: RegNetX-16GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_16gf
transfer: true
transform:
  type: image

conf/model/RegNetX-16GF.yaml
id: RegNetX-16GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_16gf
transfer: false
transform:
  type: image

conf/model/RegNetX-3.2GF-T.yaml
id: RegNetX-3.2GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_3_2gf
transfer: true
transform:
  type: image

conf/model/RegNetX-3.2GF.yaml
id: RegNetX-3.2GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_3_2gf
transfer: false
transform:
  type: image

conf/model/RegNetX-32GF-T.yaml
id: RegNetX-32GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_32gf
transfer: true
transform:
  type: image

conf/model/RegNetX-32GF.yaml
id: RegNetX-32GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_32gf
transfer: false
transform:
  type: image

conf/model/RegNetX-400MF-T.yaml
id: RegNetX-400MF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_400mf
transfer: true
transform:
  type: image

conf/model/RegNetX-400MF.yaml
id: RegNetX-400MF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_400mf
transfer: false
transform:
  type: image

conf/model/RegNetX-800MF-T.yaml
id: RegNetX-800MF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_800mf
transfer: true
transform:
  type: image

conf/model/RegNetX-800MF.yaml
id: RegNetX-800MF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_800mf
transfer: false
transform:
  type: image

conf/model/RegNetX-8GF-T.yaml
id: RegNetX-8GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_8gf
transfer: true
transform:
  type: image

conf/model/RegNetX-8GF.yaml
id: RegNetX-8GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_x_8gf
transfer: false
transform:
  type: image

conf/model/RegNetY-1.6GF-T.yaml
id: RegNetY-1.6GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_1_6gf
transfer: true
transform:
  type: image

conf/model/RegNetY-1.6GF.yaml
id: RegNetY-1.6GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_1_6gf
transfer: false
transform:
  type: image

conf/model/RegNetY-128GF-T.yaml
id: RegNetY-128GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_128gf
transfer: true
transform:
  type: image

conf/model/RegNetY-128GF.yaml
id: RegNetY-128GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_128gf
transfer: false
transform:
  type: image

conf/model/RegNetY-16GF-T.yaml
id: RegNetY-16GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_16gf
transfer: true
transform:
  type: image

conf/model/RegNetY-16GF.yaml
id: RegNetY-16GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_16gf
transfer: false
transform:
  type: image

conf/model/RegNetY-3.2GF-T.yaml
id: RegNetY-3.2GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_3_2gf
transfer: true
transform:
  type: image

conf/model/RegNetY-3.2GF.yaml
id: RegNetY-3.2GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_3_2gf
transfer: false
transform:
  type: image

conf/model/RegNetY-32GF-T.yaml
id: RegNetY-32GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_32gf
transfer: true
transform:
  type: image

conf/model/RegNetY-32GF.yaml
id: RegNetY-32GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_32gf
transfer: false
transform:
  type: image

conf/model/RegNetY-400MF-T.yaml
id: RegNetY-400MF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_400mf
transfer: true
transform:
  type: image

conf/model/RegNetY-400MF.yaml
id: RegNetY-400MF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_400mf
transfer: false
transform:
  type: image

conf/model/RegNetY-800MF-T.yaml
id: RegNetY-800MF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_800mf
transfer: true
transform:
  type: image

conf/model/RegNetY-800MF.yaml
id: RegNetY-800MF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_800mf
transfer: false
transform:
  type: image

conf/model/RegNetY-8GF-T.yaml
id: RegNetY-8GF-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_8gf
transfer: true
transform:
  type: image

conf/model/RegNetY-8GF.yaml
id: RegNetY-8GF
_target_: autrainer.models.TorchvisionModel
torchvision_name: regnet_y_8gf
transfer: false
transform:
  type: image

ResNet

conf/model/ResNet-101-T.yaml
id: ResNet-101-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet101
transfer: true
transform:
  type: image

conf/model/ResNet-101.yaml
id: ResNet-101
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet101
transfer: false
transform:
  type: image

conf/model/ResNet-152-T.yaml
id: ResNet-152-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet152
transfer: true
transform:
  type: image

conf/model/ResNet-152.yaml
id: ResNet-152
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet152
transfer: false
transform:
  type: image

conf/model/ResNet-18-T.yaml
id: ResNet-18-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet18
transfer: true
transform:
  type: image

conf/model/ResNet-18.yaml
id: ResNet-18
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet18
transfer: false
transform:
  type: image

conf/model/ResNet-34-T.yaml
id: ResNet-34-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet34
transfer: true
transform:
  type: image

conf/model/ResNet-34.yaml
id: ResNet-34
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet34
transfer: false
transform:
  type: image

conf/model/ResNet-50-T.yaml
id: ResNet-50-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet50
transfer: true
transform:
  type: image

conf/model/ResNet-50.yaml
id: ResNet-50
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnet50
transfer: false
transform:
  type: image

ResNeXt

conf/model/ResNeXt-101-32x8d-T.yaml
id: ResNeXt-101-32x8d-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext101_32x8d
transfer: true
transform:
  type: image

conf/model/ResNeXt-101-32x8d.yaml
id: ResNeXt-101-32x8d
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext101_32x8d
transfer: false
transform:
  type: image

conf/model/ResNeXt-101-64x4d-T.yaml
id: ResNeXt-101-64x4d-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext101_64x4d
transfer: true
transform:
  type: image

conf/model/ResNeXt-101-64x4d.yaml
id: ResNeXt-101-64x4d
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext101_64x4d
transfer: false
transform:
  type: image

conf/model/ResNeXt-50-32x4d-T.yaml
id: ResNeXt-50-32x4d-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext50_32x4d
transfer: true
transform:
  type: image

conf/model/ResNeXt-50-32x4d.yaml
id: ResNeXt-50-32x4d
_target_: autrainer.models.TorchvisionModel
torchvision_name: resnext50_32x4d
transfer: false
transform:
  type: image

ShuffleNet

conf/model/ShuffleNetV2-0.5x-T.yaml
id: ShuffleNetV2-0.5x-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x0_5
transfer: true
transform:
  type: image

conf/model/ShuffleNetV2-0.5x.yaml
id: ShuffleNetV2-0.5x
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x0_5
transfer: false
transform:
  type: image

conf/model/ShuffleNetV2-1.0x-T.yaml
id: ShuffleNetV2-1.0x-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x1_0
transfer: true
transform:
  type: image

conf/model/ShuffleNetV2-1.0x.yaml
id: ShuffleNetV2-1.0x
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x1_0
transfer: false
transform:
  type: image

conf/model/ShuffleNetV2-1.5x-T.yaml
id: ShuffleNetV2-1.5x-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x1_5
transfer: true
transform:
  type: image

conf/model/ShuffleNetV2-1.5x.yaml
id: ShuffleNetV2-1.5x
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x1_5
transfer: false
transform:
  type: image

conf/model/ShuffleNetV2-2.0x-T.yaml
id: ShuffleNetV2-2.0x-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x2_0
transfer: true
transform:
  type: image

conf/model/ShuffleNetV2-2.0x.yaml
id: ShuffleNetV2-2.0x
_target_: autrainer.models.TorchvisionModel
torchvision_name: shufflenet_v2_x2_0
transfer: false
transform:
  type: image

SqueezeNet

conf/model/SqueezeNet-1.0-T.yaml
id: SqueezeNet-1.0-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: squeezenet1_0
transfer: true
transform:
  type: image

conf/model/SqueezeNet-1.0.yaml
id: SqueezeNet-1.0
_target_: autrainer.models.TorchvisionModel
torchvision_name: squeezenet1_0
transfer: false
transform:
  type: image

conf/model/SqueezeNet-1.1-T.yaml
id: SqueezeNet-1.1-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: squeezenet1_1
transfer: true
transform:
  type: image

conf/model/SqueezeNet-1.1.yaml
id: SqueezeNet-1.1
_target_: autrainer.models.TorchvisionModel
torchvision_name: squeezenet1_1
transfer: false
transform:
  type: image

Swin

conf/model/Swin-B-T.yaml
id: Swin-B-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_b
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

conf/model/Swin-B.yaml
id: Swin-B
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_b
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

conf/model/Swin-S-T.yaml
id: Swin-S-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_s
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

conf/model/Swin-S.yaml
id: Swin-S
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_s
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

conf/model/Swin-T-T.yaml
id: Swin-T-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_t
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

conf/model/Swin-T.yaml
id: Swin-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_t
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 224
        width: 224

conf/model/Swin-V2-B-T.yaml
id: Swin-V2-B-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_b
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

conf/model/Swin-V2-B.yaml
id: Swin-V2-B
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_b
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

conf/model/Swin-V2-S-T.yaml
id: Swin-V2-S-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_s
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

conf/model/Swin-V2-S.yaml
id: Swin-V2-S
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_s
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

conf/model/Swin-V2-T-T.yaml
id: Swin-V2-T-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_t
transfer: true
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

conf/model/Swin-V2-T.yaml
id: Swin-V2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: swin_v2_t
transfer: false
transform:
  type: image
  base:
    - autrainer.transforms.Resize:
        height: 256
        width: 256

VGG

conf/model/VGG-11-BN-T.yaml
id: VGG-11-BN-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg11_bn
transfer: true
transform:
  type: image

conf/model/VGG-11-BN.yaml
id: VGG-11-BN
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg11_bn
transfer: false
transform:
  type: image

conf/model/VGG-11-T.yaml
id: VGG-11-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg11
transfer: true
transform:
  type: image

conf/model/VGG-11.yaml
id: VGG-11
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg11
transfer: false
transform:
  type: image

conf/model/VGG-13-BN-T.yaml
id: VGG-13-BN-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg13_bn
transfer: true
transform:
  type: image

conf/model/VGG-13-BN.yaml
id: VGG-13-BN
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg13_bn
transfer: false
transform:
  type: image

conf/model/VGG-13-T.yaml
id: VGG-13-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg13
transfer: true
transform:
  type: image

conf/model/VGG-13.yaml
id: VGG-13
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg13
transfer: false
transform:
  type: image

conf/model/VGG-16-BN-T.yaml
id: VGG-16-BN-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg16_bn
transfer: true
transform:
  type: image

conf/model/VGG-16-BN.yaml
id: VGG-16-BN
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg16_bn
transfer: false
transform:
  type: image

conf/model/VGG-16-T.yaml
id: VGG-16-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg16
transfer: true
transform:
  type: image

conf/model/VGG-16.yaml
id: VGG-16
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg16
transfer: false
transform:
  type: image

conf/model/VGG-19-BN-T.yaml
id: VGG-19-BN-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg19_bn
transfer: true
transform:
  type: image

conf/model/VGG-19-BN.yaml
id: VGG-19-BN
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg19_bn
transfer: false
transform:
  type: image

conf/model/VGG-19-T.yaml
id: VGG-19-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg19
transfer: true
transform:
  type: image

conf/model/VGG-19.yaml
id: VGG-19
_target_: autrainer.models.TorchvisionModel
torchvision_name: vgg19
transfer: false
transform:
  type: image

Wide-ResNet

conf/model/Wide-ResNet-101-2-T.yaml
id: Wide-ResNet-101-2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: wide_resnet101_2
transfer: true
transform:
  type: image

conf/model/Wide-ResNet-101-2.yaml
id: Wide-ResNet-101-2
_target_: autrainer.models.TorchvisionModel
torchvision_name: wide_resnet101_2
transfer: false
transform:
  type: image

conf/model/Wide-ResNet-50-2-T.yaml
id: Wide-ResNet-50-2-T
_target_: autrainer.models.TorchvisionModel
torchvision_name: wide_resnet50_2
transfer: true
transform:
  type: image

conf/model/Wide-ResNet-50-2.yaml
id: Wide-ResNet-50-2
_target_: autrainer.models.TorchvisionModel
torchvision_name: wide_resnet50_2
transfer: false
transform:
  type: image

class autrainer.models.TimmModel(output_dim, timm_name, transfer=False, **kwargs)

Wrapper for timm models.

Parameters:
  • output_dim (int) – Number of output classes.

  • timm_name (str) – Name of the model available in timm.create_model.

  • transfer – Whether to load the model with pretrained weights. The final layer is replaced with a new layer with output_dim output features. Defaults to False.

  • kwargs – Additional arguments to pass to the model constructor.

Default Configurations

autrainer does not provide default configurations for timm models. To discover the available timm models to create a custom configuration, refer to the timm documentation.
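A custom timm configuration can plausibly follow the same key layout as the torchvision configurations above, swapping the _target_ and the model-name key. The sketch below renders such a file from Python; the id, the timm name resnet18, the file name, and the choice of an image transform are illustrative assumptions, so verify the model name against timm (for example with timm.list_models()) before use:

```python
def timm_config(config_id: str, timm_name: str, transfer: bool) -> str:
    """Render a model config for autrainer.models.TimmModel as YAML text.

    Hypothetical helper mirroring the key layout of the default torchvision
    configs; ``timm_name`` must be a name accepted by timm.create_model.
    """
    lines = [
        f"id: {config_id}",
        "_target_: autrainer.models.TimmModel",
        f"timm_name: {timm_name}",
        f"transfer: {'true' if transfer else 'false'}",
        "",
        "transform:",
        "  type: image",  # assumption: timm vision models take image inputs
    ]
    return "\n".join(lines) + "\n"


# Could be saved as e.g. conf/model/Timm-ResNet-18-T.yaml (illustrative name).
print(timm_config("Timm-ResNet-18-T", "resnet18", transfer=True), end="")
```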

Audio Models

autrainer provides a number of different audio-specific models.

class autrainer.models.ASTModel(output_dim, num_hidden_layers=12, hidden_size=128, dropout=0.5, transfer=None)

Audio Spectrogram Transformer (AST) model. For more information see: https://huggingface.co/docs/transformers/v4.31.0/en/model_doc/audio-spectrogram-transformer

Parameters:
  • output_dim (int) – Output dimension of the model.

  • num_hidden_layers (int) – Number of hidden layers in the transformer. Defaults to 12.

  • hidden_size (int) – Hidden size of the linear layer. Defaults to 128.

  • dropout (float) – Dropout rate. Defaults to 0.5.

  • transfer (Optional[str]) – Name of the pretrained model to load. If None, the default AST fine-tuned on AudioSet is used. Defaults to None. For more information see: https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593

Default Configurations
conf/model/ASTModel-T.yaml
id: ASTModel-T
_target_: autrainer.models.ASTModel
transfer: MIT/ast-finetuned-audioset-10-10-0.4593

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: AST
        fe_transfer: MIT/ast-finetuned-audioset-10-10-0.4593
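
To get a feel for the input size of ASTModel-T, the number of spectrogram patches can be worked out by hand. The sketch below assumes the Hugging Face ASTConfig defaults used by the MIT/ast-finetuned-audioset-10-10-0.4593 checkpoint: 128 mel bins, 1024 time frames, 16x16 patches, and frequency and time strides of 10 (the 10-10 in the checkpoint name). Other checkpoints may use different values.

```python
def ast_num_patches(
    num_mel_bins: int = 128,
    max_length: int = 1024,
    patch_size: int = 16,
    frequency_stride: int = 10,
    time_stride: int = 10,
) -> int:
    """Count the overlapping patches AST extracts from one spectrogram."""
    # A sliding window of size patch_size with the given stride fits
    # (length - patch_size) // stride + 1 times along each axis.
    freq_patches = (num_mel_bins - patch_size) // frequency_stride + 1
    time_patches = (max_length - patch_size) // time_stride + 1
    return freq_patches * time_patches


print(ast_num_patches())  # 12 * 101 = 1212
```

Each patch becomes one token of the transformer input, so the patch count (plus any special tokens the model adds) sets the sequence length processed by each of the num_hidden_layers layers.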
class autrainer.models.AudioRNNModel(output_dim, model_name, hidden_size=256, num_layers=2, dropout=0.5, cell='LSTM', bidirectional=False)

Audio RNN model.

Parameters:
  • output_dim (int) – Output dimension of the model.

  • model_name (str) – Model name in [“emo18”, “zhao19”].

  • hidden_size (int) – Hidden size of the RNN. Defaults to 256.

  • num_layers (int) – Number of layers of the RNN. Defaults to 2.

  • dropout (float) – Dropout rate. Defaults to 0.5.

  • cell (str) – Type of RNN cell in [“LSTM”, “GRU”]. Defaults to “LSTM”.

  • bidirectional (bool) – Whether to use a bidirectional RNN. Defaults to False.

Default Configurations
conf/model/End2You-emo18.yaml
id: End2You-emo18
_target_: autrainer.models.AudioRNNModel
model_name: emo18

transform:
  type: raw

conf/model/End2You-zhao19.yaml
id: End2You-zhao19
_target_: autrainer.models.AudioRNNModel
model_name: zhao19

transform:
  type: raw
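
The documented choices for model_name and cell, and the effect of bidirectional, can be checked before training. The sketch below is a hypothetical helper, not part of autrainer; it assumes the model uses standard PyTorch recurrent layers, where a bidirectional RNN concatenates its forward and backward states and therefore outputs 2 * hidden_size features per time step.

```python
MODEL_NAMES = ("emo18", "zhao19")
CELLS = ("LSTM", "GRU")


def rnn_output_features(
    model_name: str,
    hidden_size: int = 256,
    cell: str = "LSTM",
    bidirectional: bool = False,
) -> int:
    """Validate AudioRNNModel arguments and return the RNN feature size per step."""
    if model_name not in MODEL_NAMES:
        raise ValueError(f"model_name must be one of {MODEL_NAMES}, got {model_name!r}")
    if cell not in CELLS:
        raise ValueError(f"cell must be one of {CELLS}, got {cell!r}")
    # Forward and backward hidden states are concatenated in a bidirectional RNN.
    return hidden_size * (2 if bidirectional else 1)


print(rnn_output_features("emo18"))  # 256
print(rnn_output_features("zhao19", cell="GRU", bidirectional=True))  # 512
```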
class autrainer.models.Cnn10(output_dim, segmentwise=False, in_channels=1, transfer=None)

CNN10 PANN model. For more information see: https://doi.org/10.48550/arXiv.1912.10211

Parameters:
  • output_dim (int) – Output dimension of the model.

  • segmentwise (bool) – Whether to use the segmentwise path instead of the clipwise path. Defaults to False.

  • in_channels (int) – Number of input channels. Defaults to 1.

  • transfer (Optional[str]) – URL of the pretrained weights to transfer. If None, the weights will be randomly initialized. Defaults to None.

Default Configurations
conf/model/Cnn10-32k-T-online.yaml
id: Cnn10-32k-T-online
_target_: autrainer.models.Cnn10
transfer: https://zenodo.org/records/3987831/files/Cnn10_mAP%3D0.380.pth

transform:
  type: raw
  base:
    - autrainer.transforms.Resample:
        current_sr: 48000
        target_sr: 32000
    - autrainer.transforms.PannMel:
        sample_rate: 32000
        window_size: 1024
        hop_size: 320
        mel_bins: 64
        fmin: 50
        fmax: 14000
        ref: 1.0
        amin: 1e-10
        top_db: null

conf/model/Cnn10-32k-T.yaml
id: Cnn10-32k-T
_target_: autrainer.models.Cnn10
transfer: https://zenodo.org/records/3987831/files/Cnn10_mAP%3D0.380.pth

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null

conf/model/Cnn10.yaml
id: Cnn10
_target_: autrainer.models.Cnn10

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null
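
The online configuration Cnn10-32k-T-online resamples 48 kHz audio to the 32 kHz rate of the pretrained checkpoint before computing log-mel features. The arithmetic below translates the PannMel parameters into time units and estimates how many frames a clip yields. It is a sketch that assumes a centered STFT (the librosa convention, where the frame count is 1 + num_samples // hop_size); the exact count in autrainer may differ by a frame.

```python
def pann_mel_summary(
    duration_s: float,
    current_sr: int = 48000,
    target_sr: int = 32000,
    window_size: int = 1024,
    hop_size: int = 320,
) -> dict:
    """Describe the Resample + PannMel pipeline of Cnn10-32k-T-online in numbers."""
    input_samples = round(duration_s * current_sr)
    # Resampling keeps the duration but changes the number of samples.
    resampled = round(duration_s * target_sr)
    return {
        "input_samples": input_samples,
        "resampled_samples": resampled,
        "window_ms": 1000 * window_size / target_sr,
        "hop_ms": 1000 * hop_size / target_sr,
        # Assumes a centered STFT, as in librosa.
        "frames": 1 + resampled // hop_size,
    }


# 10 s -> 320000 samples at 32 kHz, 32 ms windows, 10 ms hops, 1001 frames
print(pann_mel_summary(10.0))
```

With mel_bins: 64, a 10 s clip therefore becomes a spectrogram of roughly 1001 x 64 values. Cnn10-32k-T instead expects precomputed spectrograms (type: grayscale), which presumably need to be extracted with the same parameters to match the pretrained weights.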
class autrainer.models.Cnn14(output_dim, segmentwise=False, in_channels=1, transfer=None)

CNN14 PANN model. For more information see: https://doi.org/10.48550/arXiv.1912.10211

Parameters:
  • output_dim (int) – Output dimension of the model.

  • segmentwise (bool) – Whether to use the segmentwise path instead of the clipwise path. Defaults to False.

  • in_channels (int) – Number of input channels. Defaults to 1.

  • transfer (Optional[str]) – URL of the pretrained weights to transfer. If None, the weights will be randomly initialized. Defaults to None.

Default Configurations
conf/model/Cnn14-16k-T.yaml
id: Cnn14-16k-T
_target_: autrainer.models.Cnn14
transfer: https://zenodo.org/records/3987831/files/Cnn14_16k_mAP%3D0.438.pth

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null

conf/model/Cnn14-32k-T.yaml
id: Cnn14-32k-T
_target_: autrainer.models.Cnn14
transfer: https://zenodo.org/records/3987831/files/Cnn14_mAP%3D0.431.pth

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null

conf/model/Cnn14.yaml
id: Cnn14
_target_: autrainer.models.Cnn14

transform:
  type: grayscale
  base:
    - autrainer.transforms.Normalize: null
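
The transfer values of the pretrained PANN configurations are percent-encoded URLs (%3D encodes =), and the decoded file names carry the AudioSet mAP reported for each checkpoint. A short standard-library sketch decodes them; the URLs are copied from the default configurations above, and the helper names are hypothetical.

```python
from posixpath import basename
from urllib.parse import unquote, urlsplit

PANN_CHECKPOINTS = {
    "Cnn10-32k-T": "https://zenodo.org/records/3987831/files/Cnn10_mAP%3D0.380.pth",
    "Cnn14-16k-T": "https://zenodo.org/records/3987831/files/Cnn14_16k_mAP%3D0.438.pth",
    "Cnn14-32k-T": "https://zenodo.org/records/3987831/files/Cnn14_mAP%3D0.431.pth",
}


def checkpoint_filename(url: str) -> str:
    """Take the last URL path segment and decode percent-escapes (%3D -> '=')."""
    return unquote(basename(urlsplit(url).path))


def reported_map(url: str) -> float:
    """Parse the mAP value encoded in a checkpoint file name such as 'Cnn14_mAP=0.431.pth'."""
    name = checkpoint_filename(url)
    return float(name.split("mAP=")[1][: -len(".pth")])


for config_id, url in PANN_CHECKPOINTS.items():
    print(config_id, checkpoint_filename(url), reported_map(url))
```

Judging by the configuration names, the 16k and 32k Cnn14 checkpoints correspond to 16 kHz and 32 kHz audio, so precomputed spectrograms should presumably be extracted at the sample rate of the chosen checkpoint.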
class autrainer.models.FFNN(output_dim, input_size, hidden_size, num_layers=2, dropout=0.5)

Feedforward neural network.

Parameters:
  • output_dim (int) – Output dimension.

  • input_size (int) – Input size.

  • hidden_size (int) – Hidden size.

  • num_layers (int) – Number of layers. Defaults to 2.

  • dropout (float) – Dropout rate. Defaults to 0.5.

Default Configurations
conf/model/ToyFFNN.yaml
id: ToyFFNN
_target_: autrainer.models.FFNN
input_size: 64
hidden_size: 64
num_layers: 2

transform:
  type: tabular
class autrainer.models.LEAFNet(output_dim, leaf_filters=40, kernel_size=25, stride=0.0625, window_stride=10, padding_kernel_size=25, sample_rate=16000, min_freq=60, max_freq=7800, efficientnet_type='efficientnet_b0', mode='interspeech', initialization='mel', generator_seed=42, transfer=False)

EfficientNet with a LEAF-Is frontend, used to reproduce the work from: https://www.isca-archive.org/interspeech_2023/meng23c_interspeech.html

See also the original LEAF and PCEN papers (cf. SpeechBrain).

The LEAF frontend implementation is taken from Hanyu-Meng/Adapting-LEAF with slight adaptations.

Parameters:
  • output_dim (int) – Output dimension.

  • leaf_filters (int) – Number of LEAF filterbanks to train. Defaults to 40.

  • kernel_size (int) – Size of kernels applied by LEAF (in ms). Defaults to 25.

  • stride (float) – Stride of LEAF (in ms). Defaults to 0.0625.

  • window_stride (int) – Stride of lowpass filter (in ms). Defaults to 10.

  • padding_kernel_size (int) – Size of lowpass filter (in ms). Defaults to 25.

  • sample_rate (int) – Sample rate (in Hz) used to compute the LEAF parameters. Defaults to 16000.

  • min_freq (int) – Minimum frequency (in Hz) analyzed by LEAF. Defaults to 60.

  • max_freq (int) – Maximum frequency (in Hz) analyzed by LEAF. Defaults to 7800.

  • efficientnet_type (str) – EfficientNet type to use from timm. Defaults to “efficientnet_b0”.

  • mode (str) – Implementation to follow: “interspeech” (the paper by Meng et al.) or “speech_brain”. Defaults to “interspeech”.

  • initialization (str) – Filterbank initialization in [“mel”, “bark”, “linear-constant”, “constant”, “uniform”, “zeros”]. Defaults to “mel”.

  • generator_seed (int) – Seed for random generator. Defaults to 42.

  • transfer (bool) – Whether to use EfficientNet weights from ImageNet. Defaults to False.

Raises:
  • ValueError – If efficientnet_type is not supported.

  • ValueError – If mode is not supported.

Default Configurations
conf/model/LEAFNet.yaml
id: LEAFNet
_target_: autrainer.models.LEAFNet

transform:
  type: raw
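
Several LEAFNet parameters are given in milliseconds and converted to samples using sample_rate. The sketch below performs that conversion for the defaults with the plain formula ms * sample_rate / 1000; it ignores any rounding or extra sample the actual implementation may add to kernel lengths. At 16 kHz, the default stride of 0.0625 ms is exactly one sample, and a window_stride of 10 ms gives 100 output frames per second.

```python
LEAF_DEFAULTS_MS = {
    "kernel_size": 25,
    "stride": 0.0625,
    "window_stride": 10,
    "padding_kernel_size": 25,
}


def ms_to_samples(ms: float, sample_rate: int = 16000) -> float:
    """Convert a duration in milliseconds to a (possibly fractional) sample count."""
    return ms * sample_rate / 1000


def frames_per_second(window_stride_ms: float = 10) -> float:
    """The lowpass pooling moves by window_stride, which sets the output frame rate."""
    return 1000 / window_stride_ms


for name, ms in LEAF_DEFAULTS_MS.items():
    print(f"{name}: {ms} ms -> {ms_to_samples(ms):g} samples")
print(f"output frame rate: {frames_per_second():g} frames/s")
```

For audio at a different sample rate, sample_rate presumably has to be changed accordingly, and the same millisecond values then map to different sample counts.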
class autrainer.models.SeqFFNN(output_dim, backbone_input_dim, backbone_hidden_size, backbone_num_layers, hidden_size, num_layers=2, dropout=0.5, backbone_dropout=0.5, backbone_cell='LSTM', backbone_time_pooling=True, backbone_bidirectional=False)

Sequential model with FFNN frontend.

Parameters:
  • output_dim (int) – Output dimension of the FFNN.

  • backbone_input_dim (int) – Input dimension of the backbone.

  • backbone_hidden_size (int) – Hidden size of the backbone.

  • backbone_num_layers (int) – Number of layers of the backbone.

  • hidden_size (int) – Hidden size of the FFNN.

  • num_layers (int) – Number of layers of the FFNN. Defaults to 2.

  • dropout (float) – Dropout rate of the FFNN. Defaults to 0.5.

  • backbone_dropout (float) – Dropout rate of the backbone. Defaults to 0.5.

  • backbone_cell (str) – Cell type of the backbone in [“LSTM”, “GRU”]. Defaults to “LSTM”.

  • backbone_time_pooling (bool) – Whether to apply time pooling in the backbone. Defaults to True.

  • backbone_bidirectional (bool) – Whether to use a bidirectional backbone. Defaults to False.

Default Configurations
conf/model/Seq-FFNN-eGeMAPS.yaml
id: Seq-FFNN-eGeMAPS
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 25
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

conf/model/Seq-FFNN-IS09.yaml
id: Seq-FFNN-IS09
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 32
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

conf/model/Seq-FFNN-IS10.yaml
id: Seq-FFNN-IS10
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 76
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

conf/model/Seq-FFNN-IS11.yaml
id: Seq-FFNN-IS11
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 118
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

conf/model/Seq-FFNN-IS12.yaml
id: Seq-FFNN-IS12
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 128
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

conf/model/Seq-FFNN-IS13.yaml
id: Seq-FFNN-IS13
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 130
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular

conf/model/Seq-FFNN-IS16.yaml
id: Seq-FFNN-IS16
_target_: autrainer.models.SeqFFNN
backbone_input_dim: 130
backbone_cell: LSTM
backbone_hidden_size: 32
backbone_num_layers: 2
hidden_size: 32
num_layers: 2
dropout: 0.5

transform:
  type: tabular
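
The default SeqFFNN configurations differ only in backbone_input_dim, which has to match the number of features per time step of the chosen feature set (named after eGeMAPS and the INTERSPEECH openSMILE feature sets, presumably as low-level descriptors). A hypothetical check, assuming each sample is a sequence of frames with one feature vector per frame, could look like this; the dimensions are copied from the configurations above.

```python
from __future__ import annotations

SEQ_FFNN_INPUT_DIMS = {
    "Seq-FFNN-eGeMAPS": 25,
    "Seq-FFNN-IS09": 32,
    "Seq-FFNN-IS10": 76,
    "Seq-FFNN-IS11": 118,
    "Seq-FFNN-IS12": 128,
    "Seq-FFNN-IS13": 130,
    "Seq-FFNN-IS16": 130,
}


def check_sequence(config_id: str, frames: list[list[float]]) -> tuple[int, int]:
    """Return (num_frames, num_features) if every frame fits the backbone input."""
    expected = SEQ_FFNN_INPUT_DIMS[config_id]
    for index, frame in enumerate(frames):
        if len(frame) != expected:
            raise ValueError(
                f"frame {index} has {len(frame)} features, "
                f"but {config_id} expects backbone_input_dim={expected}"
            )
    return len(frames), expected


# 100 frames of 25 features each, e.g. eGeMAPS descriptors.
print(check_sequence("Seq-FFNN-eGeMAPS", [[0.0] * 25 for _ in range(100)]))  # (100, 25)
```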
class autrainer.models.TDNNFFNN(output_dim, hidden_size, num_layers=2, dropout=0.5)

Time Delay Neural Network with FFNN frontend.

Parameters:
  • output_dim (int) – Output dimension.

  • hidden_size (int) – Hidden size.

  • num_layers (int) – Number of layers. Defaults to 2.

  • dropout (float) – Dropout rate. Defaults to 0.5.

Default Configurations
conf/model/TDNNFFNN-T.yaml
id: TDNNFFNN-T
_target_: autrainer.models.TDNNFFNN
hidden_size: 32

transform:
  type: raw
class autrainer.models.W2V2FFNN(output_dim, model_name, freeze_extractor, hidden_size, num_layers=2, dropout=0.5)

Wav2Vec2 model with FFNN frontend adapted for audio classification. For more information, see: https://huggingface.co/docs/transformers/model_doc/wav2vec2

Parameters:
  • output_dim (int) – Output dimension of the FFNN.

  • model_name (str) – Name of the model loaded from Hugging Face.

  • freeze_extractor (bool) – Whether to freeze the feature extractor.

  • hidden_size (int) – Hidden size of the FFNN.

  • num_layers (int) – Number of layers of the FFNN. Defaults to 2.

  • dropout (float) – Dropout rate. Defaults to 0.5.

Default Configurations
conf/model/w2v2-b.yaml
id: w2v2-b
_target_: autrainer.models.W2V2FFNN
model_name: facebook/wav2vec2-base
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: facebook/wav2vec2-base

conf/model/w2v2-l-100k.yaml
id: w2v2-l-100k
_target_: autrainer.models.W2V2FFNN
model_name: facebook/wav2vec2-large-100k-voxpopuli
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: facebook/wav2vec2-large-100k-voxpopuli

conf/model/w2v2-l-emo.yaml
id: w2v2-l-emo
_target_: autrainer.models.W2V2FFNN
model_name: audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim

conf/model/w2v2-l-rob.yaml
id: w2v2-l-rob
_target_: autrainer.models.W2V2FFNN
model_name: facebook/wav2vec2-large-robust
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: facebook/wav2vec2-large-robust

conf/model/w2v2-l.yaml
id: w2v2-l
_target_: autrainer.models.W2V2FFNN
model_name: facebook/wav2vec2-large
freeze_extractor: true
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: W2V2
        fe_transfer: facebook/wav2vec2-large
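
All W2V2FFNN configurations operate on raw audio, and the wav2vec 2.0 convolutional feature encoder (which freeze_extractor presumably refers to) turns it into a much shorter sequence of frames. The sketch below computes that length, assuming 16 kHz input and the Hugging Face Wav2Vec2Config defaults conv_kernel=(10, 3, 3, 3, 3, 2, 2) and conv_stride=(5, 2, 2, 2, 2, 2, 2), for an overall stride of 320 samples (one frame every 20 ms).

```python
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)


def w2v2_num_frames(num_samples: int) -> int:
    """Number of frames the wav2vec 2.0 feature encoder produces (no padding)."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        # Each 1D convolution without padding yields (L - kernel) // stride + 1 outputs.
        length = (length - kernel) // stride + 1
    return length


print(w2v2_num_frames(16000))   # 1 s at 16 kHz -> 49 frames
print(w2v2_num_frames(160000))  # 10 s at 16 kHz -> 499 frames
```

The FFNN receives representations derived from these frames; how they are pooled over time is not specified in this section.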
class autrainer.models.WhisperFFNN(output_dim, model_name, hidden_size, num_layers=2, dropout=0.5)

Whisper model with FFNN frontend adapted for audio classification. For more information, see: https://doi.org/10.48550/arXiv.2212.04356

Parameters:
  • output_dim (int) – Output dimension of the FFNN.

  • model_name (str) – Name of the model loaded from Hugging Face.

  • hidden_size (int) – Hidden size of the FFNN.

  • num_layers (int) – Number of layers of the FFNN. Defaults to 2.

  • dropout (float) – Dropout rate. Defaults to 0.5.

Default Configurations
conf/model/Whisper-FFNN-Base-T.yaml
id: Whisper-FFNN-Base-T
_target_: autrainer.models.WhisperFFNN
model_name: openai/whisper-base
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: Whisper
        fe_transfer: openai/whisper-base

conf/model/Whisper-FFNN-Small-T.yaml
id: Whisper-FFNN-Small-T
_target_: autrainer.models.WhisperFFNN
model_name: openai/whisper-small
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: Whisper
        fe_transfer: openai/whisper-small

conf/model/Whisper-FFNN-Tiny-T.yaml
id: Whisper-FFNN-Tiny-T
_target_: autrainer.models.WhisperFFNN
model_name: openai/whisper-tiny
hidden_size: 512
num_layers: 2
dropout: 0.5

transform:
  type: raw
  base:
    - autrainer.transforms.FeatureExtractor:
        fe_type: Whisper
        fe_transfer: openai/whisper-tiny
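
Unlike wav2vec 2.0, Whisper works on fixed-length inputs. Assuming the Hugging Face WhisperFeatureExtractor defaults (16 kHz audio, 30 s chunks, hop length 160) are used with their default padding, every clip is padded or truncated to 30 s, giving 3000 log-mel frames, which the Whisper encoder downsamples by a factor of 2 to 1500 positions. The hypothetical helper below summarizes this for a given clip duration.

```python
SAMPLE_RATE = 16000  # Hz, the input rate Whisper expects
CHUNK_LENGTH_S = 30  # every input is padded or truncated to this length
HOP_LENGTH = 160  # samples between consecutive log-mel frames (10 ms)


def whisper_input_summary(duration_s: float) -> dict:
    """Describe how a clip of the given duration is presented to Whisper."""
    mel_frames = CHUNK_LENGTH_S * SAMPLE_RATE // HOP_LENGTH
    return {
        "mel_frames": mel_frames,
        # The second encoder convolution has stride 2, halving the frame count.
        "encoder_positions": mel_frames // 2,
        "padding_s": max(0.0, CHUNK_LENGTH_S - duration_s),
        "truncated_s": max(0.0, duration_s - CHUNK_LENGTH_S),
    }


# mel_frames=3000, encoder_positions=1500, 26 s of padding
print(whisper_input_summary(4.0))
```

A 4 s clip thus occupies only a small part of the 30 s window, and the rest of the encoder input is padding.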