Tutorials

autrainer is designed to be flexible and extensible, allowing for the creation of custom …

For each, a tutorial is provided below to demonstrate their implementation and configuration.

For the following tutorials, all Python files should be placed in the project root directory and all configuration files in the corresponding subdirectories of the conf/ directory.

Custom Models

To create a custom model, inherit from AbstractModel and implement the forward() and embeddings() methods. Every constructor argument must be assigned to an attribute with the same name (e.g., self.hidden_dims = hidden_dims), as AbstractModel inherits from audobject, which relies on this to serialize the model.
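The naming rule can be checked before training. The following stdlib-only sketch uses a hypothetical helper, check_init_args_stored (not part of autrainer or audobject), that inspects a class's constructor signature and reports every argument that is not stored as an attribute of the same name. It only approximates what audobject expects and uses two toy classes instead of a real model.

```python
import inspect
from typing import List


def check_init_args_stored(obj: object) -> List[str]:
    """Return constructor arguments of ``obj`` that are not stored as
    attributes with the same name (hypothetical helper for illustration)."""
    signature = inspect.signature(type(obj).__init__)
    missing = []
    for name, param in signature.parameters.items():
        # skip ``self`` as well as ``*args`` and ``**kwargs``
        if name == "self" or param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue
        if not hasattr(obj, name):
            missing.append(name)
    return missing


class GoodModel:
    def __init__(self, output_dim: int, hidden_dims: List[int]) -> None:
        self.output_dim = output_dim
        self.hidden_dims = hidden_dims  # same name as the argument


class BadModel:
    def __init__(self, output_dim: int, hidden_dims: List[int]) -> None:
        self.output_dim = output_dim
        self.dims = hidden_dims  # renamed, so it cannot be recovered by name


print(check_init_args_stored(GoodModel(10, [32, 64])))  # []
print(check_init_args_stored(BadModel(10, [32, 64])))  # ['hidden_dims']
```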

For example, the following model is a simple CNN that takes a spectrogram as input and has a variable number of hidden CNN layers, each with its own number of filters:

spectrogram_cnn.py
from typing import List

import torch

from autrainer.models import AbstractModel


class SpectrogramCNN(AbstractModel):
    def __init__(self, output_dim: int, hidden_dims: List[int]) -> None:
        """Spectrogram CNN model with a variable number of hidden CNN layers.

        Args:
            output_dim: Output dimension of the model.
            hidden_dims: List of hidden dimensions for the CNN layers.
        """
        super().__init__(output_dim)
        self.hidden_dims = hidden_dims
        layers = []
        input_dim = 1
        for hidden_dim in self.hidden_dims:
            layers.extend(
                [
                    torch.nn.Conv2d(input_dim, hidden_dim, (3, 3), 1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                ]
            )
            input_dim = hidden_dim
        layers.extend(
            [
                torch.nn.AdaptiveAvgPool2d((1, 1)),
                torch.nn.Flatten(),
            ]
        )
        self.backbone = torch.nn.Sequential(*layers)
        self.classifier = torch.nn.Linear(self.hidden_dims[-1], output_dim)

    def embeddings(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.embeddings(x))
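Each hidden block shrinks the spectrogram (a 3x3 convolution without padding removes 2 rows and columns, and 2x2 max pooling halves the size), so the input must be large enough for the chosen hidden_dims. The following pure-Python sketch tracks the spatial size through the blocks; feature_map_size is a hypothetical helper and the 64 x 501 input size is an arbitrary example. AdaptiveAvgPool2d then reduces whatever remains to 1 x 1.

```python
from typing import List, Tuple


def feature_map_size(
    height: int, width: int, num_layers: int
) -> List[Tuple[int, int]]:
    """Track the spatial size through Conv2d(3x3, stride 1, no padding)
    followed by MaxPool2d(2x2), as used in each hidden block above."""
    sizes = [(height, width)]
    for _ in range(num_layers):
        height, width = height - 2, width - 2  # 3x3 convolution, no padding
        height, width = height // 2, width // 2  # 2x2 max pooling (floor)
        if height < 1 or width < 1:
            raise ValueError("Input too small for the number of hidden layers.")
        sizes.append((height, width))
    return sizes


# three hidden layers, as with hidden_dims: [32, 64, 128]
print(feature_map_size(64, 501, 3))
# [(64, 501), (31, 249), (14, 123), (6, 60)]
```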

Next, create a SpectrogramCNN.yaml configuration file for the model in the conf/model/ directory:

conf/model/SpectrogramCNN.yaml
id: SpectrogramCNN
_target_: spectrogram_cnn.SpectrogramCNN

hidden_dims: [32, 64, 128]

transform:
  type: grayscale

The id should match the name of the configuration file. The _target_ should point to the custom model class via a Python import path (here assuming that spectrogram_cnn.py is located in the root directory of the project). Each model configuration should include a transform attribute with a type key (e.g., type: grayscale), specifying the input type the model expects.
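The _target_ key follows the Hydra instantiation convention: the dotted path is imported like a regular Python import and the remaining keys become constructor arguments. The following sketch is a simplified, hypothetical stand-in for that process (instantiate_from_config is not autrainer's code); that id and transform are excluded and that runtime values such as output_dim are merged in is an assumption based on the notes in this section. A standard-library class is used so that the sketch runs without a model.

```python
import importlib
from typing import Any, Dict


def instantiate_from_config(config: Dict[str, Any], **runtime_kwargs: Any) -> Any:
    """Simplified sketch of instantiating an object from a configuration.

    Assumption: ``id`` and ``transform`` are metadata and not passed to the
    constructor; runtime values (e.g. ``output_dim``) are added as kwargs.
    """
    kwargs = {
        k: v
        for k, v in config.items()
        if k not in ("id", "_target_", "transform")
    }
    # resolve "package.module.ClassName" via a regular import
    module_path, _, class_name = config["_target_"].rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs, **runtime_kwargs)


config = {
    "id": "Example",
    "_target_": "fractions.Fraction",
    "numerator": 3,
    "transform": {"type": "grayscale"},
}
print(instantiate_from_config(config, denominator=4))  # 3/4
```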

Note

The output_dim attribute is automatically passed to the model during initialization and determined by the dataset at runtime.

The transform attribute in the configuration is not passed to the model during initialization and is used to specify the input type of the model and any online transforms to be applied to the data at runtime.

Custom Datasets

To create a custom dataset, inherit from AbstractDataset and implement the target_transform and output_dim properties.

The train, dev, and test datasets as well as their data loaders are created automatically by the abstract class. However, this requires the dataset to follow the standard format outlined in the dataset documentation. If the dataset structure differs or does not rely on dataframes, the load_dataframes() method as well as the train_dataset, train_loader, etc. methods and properties can be overridden.

autrainer provides base datasets for classification (BaseClassificationDataset), regression (BaseRegressionDataset), and multi-label classification (BaseMLClassificationDataset) tasks. When inheriting from one of these, both the target transform and the output dimension are already implemented and do not need to be overridden.

Tip

To automatically download a custom dataset, implement the download() method. This method is called by the autrainer fetch CLI command as well as the fetch() CLI wrapper function. The path attribute specified in the dataset configuration file is passed to the method and determines where the downloaded data is stored.

ESC-50 Example

For example, the ESC-50 dataset is an audio classification dataset and can be implemented as follows:

esc_50.py
import os
import shutil
from typing import Any, List, Tuple

import pandas as pd

from autrainer.datasets import BaseClassificationDataset
from autrainer.datasets.utils import ZipDownloadManager


FILES = {"ESC-50.zip": "https://github.com/karoldvl/ESC-50/archive/master.zip"}


class ESC50(BaseClassificationDataset):
    def __init__(
        self,
        train_folds: List[int],
        dev_folds: List[int],
        test_folds: List[int],
        **kwargs: Any,  # kwargs only for simplicity in the tutorial
    ) -> None:
        self.train_folds = train_folds
        self.dev_folds = dev_folds
        self.test_folds = test_folds
        super().__init__(**kwargs)

    def load_dataframes(
        self,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        meta = pd.read_csv(os.path.join(self.path, "esc50.csv"))
        return (
            meta[meta["fold"].isin(self.train_folds)],
            meta[meta["fold"].isin(self.dev_folds)],
            meta[meta["fold"].isin(self.test_folds)],
        )

    @staticmethod
    def download(path: str) -> None:
        if os.path.exists(os.path.join(path, "default")):
            return

        dl_manager = ZipDownloadManager(FILES, path)
        dl_manager.download(check_exist=["ESC-50.zip"])
        dl_manager.extract(check_exist=["ESC-50-master"])
        shutil.move(
            os.path.join(path, "ESC-50-master", "audio"),
            os.path.join(path, "default"),
        )
        shutil.move(
            os.path.join(path, "ESC-50-master", "meta", "esc50.csv"),
            path,
        )
        shutil.rmtree(os.path.join(path, "ESC-50-master"))

The dataset provides raw audio files, which the download() method moves to the default/ directory, and the corresponding metadata is stored in the esc50.csv file.

To allow custom folds to be specified, the load_dataframes() method is overridden to split the esc50.csv file into the respective train, dev, and test dataframes based on its fold column. This also enables cross-validation by creating multiple configurations with different folds.
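ESC-50 is distributed with five predefined folds, so one configuration per fold rotation covers a full cross-validation. The following sketch generates the train_folds, dev_folds, and test_folds values for each rotation; make_cv_splits is a hypothetical helper, and using the fold after the test fold as the dev fold is only one possible scheme.

```python
from typing import Dict, List


def make_cv_splits(num_folds: int = 5) -> List[Dict[str, List[int]]]:
    """Create rotating train/dev/test fold assignments for cross-validation.

    Split k uses fold k as test fold, the next fold as dev fold, and all
    remaining folds for training.
    """
    folds = list(range(1, num_folds + 1))
    splits = []
    for k in range(num_folds):
        test = folds[k]
        dev = folds[(k + 1) % num_folds]
        train = [f for f in folds if f not in (test, dev)]
        splits.append(
            {"train_folds": train, "dev_folds": [dev], "test_folds": [test]}
        )
    return splits


# each dictionary corresponds to the fold entries of one dataset configuration
for split in make_cv_splits():
    print(split)
```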

To extract log-Mel spectrograms from the audio files, a preprocessing transform can be applied to the data before training. The following configuration creates a new ESC50-32k.yaml dataset in the conf/dataset/ directory with log-Mel spectrograms preprocessed at a sample rate of 32 kHz:

conf/dataset/ESC50-32k.yaml
id: ESC50-32k
_target_: esc_50.ESC50

path: data/ESC50
features_subdir: log_mel_32k
index_column: filename
target_column: category
file_type: npy
file_handler: autrainer.datasets.utils.NumpyFileHandler

train_folds: [1, 2, 3]
dev_folds: [4]
test_folds: [5]

criterion: autrainer.criterions.BalancedCrossEntropyLoss
metrics:
  - autrainer.metrics.Accuracy
  - autrainer.metrics.UAR
  - autrainer.metrics.F1
tracking_metric: autrainer.metrics.Accuracy

transform:
  type: grayscale

The dataset can be automatically downloaded and preprocessed using the autrainer fetch and autrainer preprocess CLI commands or the fetch() and preprocess() CLI wrapper functions.

Simple Dataset Example

If the structure of the dataset follows the standard format outlined in the dataset documentation, no implementation is necessary and a new dataset can be created by simply adding a configuration file to the conf/dataset/ directory.

For example, the following configuration file creates a new SpectrogramDataset-32k.yaml classification dataset, preprocessing the data with a spectrogram preprocessing transform at a sample rate of 32 kHz:

conf/dataset/SpectrogramDataset-32k.yaml
id: SpectrogramDataset-32k
_target_: autrainer.datasets.BaseClassificationDataset

path: data/SpectrogramDataset # base path to the dataset
features_subdir: log_mel_32k # spectrogram preprocessed features
index_column: path # column in the CSVs containing feature paths relative to features_subdir
target_column: label # column in the CSVs containing the target labels
file_type: npy # file extension of the spectrogram features
file_handler: autrainer.datasets.utils.NumpyFileHandler # file handler for the spectrogram features

criterion: autrainer.criterions.BalancedCrossEntropyLoss
metrics:
  - autrainer.metrics.Accuracy
  - autrainer.metrics.UAR
  - autrainer.metrics.F1
tracking_metric: autrainer.metrics.Accuracy

transform:
  type: grayscale

This dataset assumes that the data/SpectrogramDataset directory contains the following directories and files:

  • default/ directory containing the raw audio files. These audio files are preprocessed using the spectrogram preprocessing transform with the autrainer preprocess CLI command or the preprocess() CLI wrapper function and stored in the data/SpectrogramDataset/log_mel_32k directory.

  • train.csv, dev.csv, and test.csv files containing the file paths relative to the default/ directory in the index_column column and the corresponding labels in the target_column column.
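The CSV files can be created with the standard library alone. The following sketch writes train.csv, dev.csv, and test.csv with the path and label columns used as index_column and target_column in the configuration above; write_split_csvs is a hypothetical helper and the file names and labels are made up.

```python
import csv
import os
import tempfile
from typing import Dict, List, Tuple


def write_split_csvs(root: str, splits: Dict[str, List[Tuple[str, str]]]) -> None:
    """Write one <split>.csv per entry with ``path`` and ``label`` columns."""
    for name, rows in splits.items():
        with open(os.path.join(root, f"{name}.csv"), "w", newline="") as f:
            writer = csv.writer(f)
            writer.writerow(["path", "label"])  # header: index and target column
            writer.writerows(rows)


root = tempfile.mkdtemp()  # stands in for data/SpectrogramDataset
write_split_csvs(
    root,
    {
        "train": [("dog/001.wav", "dog"), ("cat/001.wav", "cat")],
        "dev": [("dog/002.wav", "dog")],
        "test": [("cat/002.wav", "cat")],
    },
)
with open(os.path.join(root, "train.csv"), newline="") as f:
    print(f.read())
```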

Custom Metrics

To create a custom metric, inherit from AbstractMetric and implement the starting_metric and suffix properties as well as the get_best(), get_best_pos(), and compare() static methods.

autrainer provides base classes for ascending (BaseAscendingMetric) and descending (BaseDescendingMetric) metrics that can be inherited from to simplify the implementation.

For example, the following metric implements the Cohen’s Kappa score with either linear or quadratic weights:

cohens_kappa_metric.py
import sklearn.metrics

from autrainer.metrics import BaseAscendingMetric


class CohensKappa(BaseAscendingMetric):
    def __init__(self, weights: str) -> None:
        """Cohen's Kappa metric using `sklearn.metrics.cohen_kappa_score`.

        Args:
            weights: Weighting type for the metric in ["linear", "quadratic"].
        """
        super().__init__(
            name="cohens-kappa",
            fn=sklearn.metrics.cohen_kappa_score,
            weights=weights,
        )

The fn attribute is the function that is called by the __call__() method, and the weights argument is passed to fn as a keyword argument.

As metrics are specified using shorthand syntax in the dataset configuration, the custom metric can be referenced by its relative import path, for example as the tracking_metric of the dataset:

conf/dataset/ExampleDataset.yaml
...
tracking_metric:
  cohens_kappa_metric.CohensKappa:
    weights: linear # linear or quadratic
...
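To see what the weights option changes, the following pure-Python sketch reimplements weighted Cohen's kappa for integer labels 0..K-1, following the formula used by sklearn.metrics.cohen_kappa_score. It is for illustration only (weighted_kappa is a hypothetical helper); the metric above should keep calling sklearn. Linear weights penalize a disagreement by the distance between the two classes, quadratic weights by its square, so distant errors on ordinal labels count more.

```python
from typing import List, Optional


def weighted_kappa(
    y1: List[int], y2: List[int], weights: Optional[str] = None
) -> float:
    """Cohen's kappa (unweighted, linear, or quadratic) for labels 0..K-1."""
    k = max(max(y1), max(y2)) + 1
    n = len(y1)
    observed = [[0.0] * k for _ in range(k)]
    for a, b in zip(y1, y2):
        observed[a][b] += 1
    rows = [sum(r) for r in observed]  # label counts of y1
    cols = [sum(observed[i][j] for i in range(k)) for j in range(k)]  # of y2
    num = den = 0.0
    for i in range(k):
        for j in range(k):
            if weights is None:
                w = 0.0 if i == j else 1.0
            elif weights == "linear":
                w = abs(i - j)
            elif weights == "quadratic":
                w = (i - j) ** 2
            else:
                raise ValueError(f"Unknown weights: {weights}")
            num += w * observed[i][j]  # weighted observed disagreement
            den += w * rows[i] * cols[j] / n  # weighted chance disagreement
    return 1.0 - num / den


print(weighted_kappa([0, 1, 2], [0, 2, 2], "linear"))  # approx. 0.667
print(weighted_kappa([0, 1, 2], [0, 2, 2], "quadratic"))  # approx. 0.8
```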

Custom Criterions

To create a custom criterion, inherit from torch.nn.modules.loss._Loss and implement the forward() method. If the criterion relies on the dataset, an optional criterion setup method can be defined which is called after the dataset is initialized.

For example, the following criterion implements CrossEntropyLoss with an additional scaling factor:

scaled_ce_loss.py
import torch


class ScaledCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    def __init__(self, scaling_factor: float = 1.0, *args, **kwargs):
        """Cross entropy loss with a scaling factor.

        Args:
            scaling_factor: Scaling factor for the loss.
            *args: Positional arguments passed to `torch.nn.CrossEntropyLoss`.
            **kwargs: Keyword arguments passed to `torch.nn.CrossEntropyLoss`.
        """
        super().__init__(*args, **kwargs)
        self.scaling_factor = scaling_factor

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if y.ndim == 1:
            y = y.long()
        return self.scaling_factor * super().forward(x, y)

As criterions are specified using shorthand syntax in the dataset configuration, the following relative import path can be used to reference it as the criterion for the dataset:

conf/dataset/ExampleDataset.yaml
...
criterion:
  scaled_ce_loss.ScaledCrossEntropyLoss:
    scaling_factor: 0.5
...

Or without overriding the default scaling_factor value:

conf/dataset/ExampleDataset.yaml
...
criterion: scaled_ce_loss.ScaledCrossEntropyLoss
...

Custom File Handlers

To create a custom file handler, inherit from AbstractFileHandler and implement the load() and save() methods.

For example, the following file handler loads and saves PyTorch tensors:

torch_file_handler.py
import torch

from autrainer.datasets.utils import AbstractFileHandler


class TorchFileHandler(AbstractFileHandler):
    def load(self, file: str) -> torch.Tensor:
        return torch.load(file)

    def save(self, file: str, data: torch.Tensor) -> None:
        torch.save(data, file)
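The same load()/save() pattern works for any format that has a reader and a writer. As a hypothetical, dependency-free variant, the following sketch stores data as JSON. JsonFileHandler does not inherit from AbstractFileHandler here only so that it runs without autrainer installed; in a project it would inherit from it like TorchFileHandler. Creating missing parent directories in save() is a design choice of this sketch.

```python
import json
import os
import tempfile
from typing import Any


class JsonFileHandler:
    """Standalone sketch of the load()/save() interface for JSON files.

    In an autrainer project, this class would inherit from
    autrainer.datasets.utils.AbstractFileHandler.
    """

    def load(self, file: str) -> Any:
        with open(file, "r", encoding="utf-8") as f:
            return json.load(f)

    def save(self, file: str, data: Any) -> None:
        # create missing parent directories before writing
        os.makedirs(os.path.dirname(file) or ".", exist_ok=True)
        with open(file, "w", encoding="utf-8") as f:
            json.dump(data, f)


handler = JsonFileHandler()
path = os.path.join(tempfile.mkdtemp(), "features", "example.json")
handler.save(path, {"frames": [0.1, 0.2, 0.3]})
print(handler.load(path))  # {'frames': [0.1, 0.2, 0.3]}
```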

File handlers are specified using shorthand syntax in the dataset configuration. The following configuration utilizes the TorchFileHandler to load and save PyTorch tensors with the file extension .pt:

conf/dataset/ExampleDataset.yaml
...
file_type: pt
file_handler: torch_file_handler.TorchFileHandler
...

Custom Target Transforms

To create a custom target transform, inherit from AbstractTargetTransform and implement the encode(), decode(), probabilities_inference(), predict_inference(), majority_vote(), and probabilities_to_dict() methods.

For example, the following target transform logarithmically encodes and decodes the targets for regression tasks:

log_target_transform.py
import math
from typing import Dict, List, Union

import torch

from autrainer.datasets.utils import AbstractTargetTransform


class LogTargetTransform(AbstractTargetTransform):
    def __init__(self, target: str, base: int = 10, eps: float = 1e-9) -> None:
        """Logarithmic target transform for regression tasks.

        Args:
            target: Name of the target.
            base: Base of the logarithm. Defaults to 10.
            eps: Small value to avoid taking the logarithm of zero.
                Defaults to 1e-9.
        """
        self.target = target
        self.base = base
        self.eps = eps

    def encode(self, x: float) -> float:
        return math.log(x + self.eps, self.base)

    def decode(self, x: float) -> float:
        return math.pow(self.base, x) - self.eps

    def probabilities_inference(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def predict_inference(self, x: torch.Tensor) -> Union[List[float], float]:
        return x.squeeze().tolist()

    def majority_vote(self, x: List[float]) -> float:
        return sum(x) / len(x)

    def probabilities_to_dict(self, x: torch.Tensor) -> Dict[str, float]:
        return {self.target: x.item()}
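Because decode() must invert encode(), a quick round-trip check is worthwhile. The following sketch copies the arithmetic of LogTargetTransform into a torch-free class (LogTransformSketch is a hypothetical name used only here) so it runs without autrainer; note how eps keeps encode(0.0) finite, and that majority_vote() averages predictions for regression.

```python
import math
from typing import List


class LogTransformSketch:
    """Torch-free copy of the encode()/decode()/majority_vote() logic."""

    def __init__(self, base: int = 10, eps: float = 1e-9) -> None:
        self.base = base
        self.eps = eps

    def encode(self, x: float) -> float:
        return math.log(x + self.eps, self.base)

    def decode(self, x: float) -> float:
        return math.pow(self.base, x) - self.eps

    def majority_vote(self, x: List[float]) -> float:
        # for regression, the "vote" is the mean of the predictions
        return sum(x) / len(x)


t = LogTransformSketch()
for value in [1.0, 250.0]:
    encoded = t.encode(value)
    print(f"{value} -> {encoded:.4f} -> {t.decode(encoded):.4f}")
print(t.encode(0.0))  # finite thanks to eps (log10 of 1e-9)
```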

The target transforms are specified in the target_transform property of a dataset implementation.

Custom Optimizers

To create a custom optimizer, inherit from torch.optim.Optimizer and implement the step() method, or extend an existing optimizer.

For example, the following optimizer extends the SGD optimizer (torch.optim.SGD) and randomly scales its learning rate in a custom step function (custom_step()):

random_scaled_sgd.py
from typing import Callable, Optional, Tuple

import torch


class RandomScaledSGD(torch.optim.SGD):
    def __init__(
        self,
        scaling_factor: float = 0.01,
        p: float = 1.0,
        generator_seed: Optional[int] = None,
        *args,
        **kwargs,
    ) -> None:
        """Randomized Scaled SGD optimizer. Randomly scales the learning rate.

        Args:
            scaling_factor: Learning rate scaling factor. Defaults to 0.01.
            p: Probability of scaling the learning rate. Defaults to 1.0.
            generator_seed: Seed for the random number generator.
                Defaults to None.
            *args: Positional arguments passed to `torch.optim.SGD`.
            **kwargs: Keyword arguments passed to `torch.optim.SGD`.
        """
        super().__init__(*args, **kwargs)
        self.scaling_factor = scaling_factor
        self.p = p
        self.g = torch.Generator()
        if generator_seed is not None:  # also honor a seed of 0
            self.g.manual_seed(generator_seed)

    def custom_step(
        self,
        model: torch.nn.Module,  # model
        data: torch.Tensor,  # batched input data
        target: torch.Tensor,  # batched target data
        criterion: torch.nn.Module,  # loss function
        probabilities_fn: Callable,  # function to get probabilities from model outputs
    ) -> Tuple[float, torch.Tensor]:
        self.zero_grad()
        output = model(data)
        loss = criterion(probabilities_fn(output), target)
        loss.backward()
        # remember the current learning rate (a scheduler may have changed it)
        lr = self.param_groups[0]["lr"]
        if torch.rand(1, generator=self.g).item() < self.p:
            self.param_groups[0]["lr"] *= self.scaling_factor
        self.step()
        self.param_groups[0]["lr"] = lr  # restore the unscaled learning rate
        return loss.item(), output.detach()

The following configuration creates a new RandomScaledSGD.yaml optimizer in the conf/optimizer/ directory and uses the global seed of the main configuration as the generator_seed attribute:

conf/optimizer/RandomScaledSGD.yaml
id: RandomScaledSGD
_target_: random_scaled_sgd.RandomScaledSGD

scaling_factor: 0.001
p: 0.05
generator_seed: ${seed}

Note

The params and lr attributes are automatically passed to the optimizer during initialization and determined at runtime.
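One detail worth checking in such an optimizer is which learning rate is restored after a scaled step: restoring the value read at the start of the step, rather than a value stored at construction time, preserves any change a scheduler made in the meantime. The following pure-Python sketch isolates that save-scale-restore pattern, with plain dictionaries standing in for param_groups and random.Random standing in for torch.Generator; scaled_step_lr is a hypothetical helper.

```python
import random
from typing import Dict, List


def scaled_step_lr(
    param_groups: List[Dict[str, float]],
    rng: random.Random,
    p: float,
    scaling_factor: float,
) -> float:
    """Return the learning rate used for one step and restore it afterwards."""
    lr = param_groups[0]["lr"]  # current value, possibly set by a scheduler
    if rng.random() < p:
        param_groups[0]["lr"] *= scaling_factor
    used_lr = param_groups[0]["lr"]
    # ... optimizer.step() would run here with the (possibly scaled) lr ...
    param_groups[0]["lr"] = lr
    return used_lr


groups = [{"lr": 0.1}]
rng = random.Random(0)
used = [scaled_step_lr(groups, rng, p=0.5, scaling_factor=0.01) for _ in range(8)]
groups[0]["lr"] = 0.05  # e.g. a scheduler lowers the learning rate
used.append(scaled_step_lr(groups, rng, p=0.0, scaling_factor=0.01))
print(used, groups[0]["lr"])
```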

Custom Schedulers

To create a custom scheduler, inherit from torch.optim.lr_scheduler.LRScheduler and implement the get_lr() method.

For example, the following scheduler implements a simple linear warm-up scheduler:

linear_warm_up_lr.py
from typing import List

import torch
from torch.optim.lr_scheduler import LRScheduler


class LinearWarmUpLR(LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        last_epoch: int = -1,
    ) -> None:
        """Linear warm-up learning rate scheduler.

        Args:
            optimizer: Wrapped optimizer.
            warmup_steps: Number of warmup steps.
            last_epoch: The index of last epoch. Defaults to -1.
        """
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.last_epoch < self.warmup_steps:
            return [
                base_lr * (self.last_epoch + 1) / self.warmup_steps
                for base_lr in self.base_lrs
            ]
        return self.base_lrs

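The following configuration creates a new LinearWarmUpLR.yaml scheduler in the conf/scheduler/ directory with a linear warm-up period of 10 scheduler steps. As step_frequency is set to evaluation, the scheduler is stepped at each evaluation rather than at each training iteration: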

conf/scheduler/LinearWarmUpLR.yaml
id: LinearWarmUpLR
_target_: linear_warm_up_lr.LinearWarmUpLR

warmup_steps: 10

step_frequency: evaluation

Note

The optimizer attribute is automatically passed to the scheduler during initialization and determined at runtime.
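To see the resulting schedule without constructing an optimizer, the following sketch evaluates the get_lr() formula of LinearWarmUpLR in pure Python; linear_warmup_lrs is a hypothetical helper, and base_lr=0.01 and warmup_steps=4 are arbitrary example values. Since LRScheduler performs an initial step when it is constructed (moving last_epoch from -1 to 0), the first learning rate actually used is base_lr / warmup_steps rather than 0.

```python
from typing import List


def linear_warmup_lrs(base_lr: float, warmup_steps: int, num_steps: int) -> List[float]:
    """Learning rates given by the get_lr() formula for
    last_epoch = 0, 1, ..., num_steps - 1."""
    lrs = []
    for last_epoch in range(num_steps):
        if last_epoch < warmup_steps:
            # linear ramp from base_lr / warmup_steps up to base_lr
            lrs.append(base_lr * (last_epoch + 1) / warmup_steps)
        else:
            lrs.append(base_lr)  # constant after the warm-up
    return lrs


print(linear_warmup_lrs(base_lr=0.01, warmup_steps=4, num_steps=6))
```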

Custom Transforms

To create a custom transform, inherit from AbstractTransform and implement the __call__() method.

For example, the following transform denoises a spectrogram by applying a median filter:

spect_median_filter.py
import scipy.ndimage
import torch

from autrainer.transforms import AbstractTransform


class SpectMedianFilter(AbstractTransform):
    def __init__(self, size: int, order: int = 0) -> None:
        """Spectrogram median filter to remove noise.

        Args:
            size: Number of neighboring pixels to consider when filtering.
                Must be odd.
            order: The order of the transform in the pipeline. Larger means
                later in the pipeline. If multiple transforms have the same
                order, they are applied in the order they were added to the
                pipeline. Defaults to 0.
        """
        super().__init__(order=order)
        self.size = size

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.from_numpy(
            scipy.ndimage.median_filter(
                x.cpu().numpy(),
                size=self.size,
            )
        ).to(x.device)
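SpectMedianFilter replaces each value by the median of a neighborhood spanning size elements along every axis. The following pure-Python sketch shows the same operation in one dimension: a short impulsive outlier (such as a click) is removed, while a monotonic ramp passes through unchanged. median_filter_1d is a hypothetical, illustrative helper; the boundary handling mirrors the default 'reflect' mode of scipy.ndimage.median_filter, and the transform itself should keep using scipy.

```python
import statistics
from typing import List


def median_filter_1d(x: List[float], size: int) -> List[float]:
    """1-D median filter with 'reflect' boundaries (d c b a | a b c d | d c b a)."""
    n = len(x)
    if size > n:
        raise ValueError("size must not exceed the signal length")
    half = size // 2
    out = []
    for i in range(n):
        window = []
        for j in range(i - half, i - half + size):
            if j < 0:
                j = -j - 1  # reflect across the left edge
            elif j >= n:
                j = 2 * n - j - 1  # reflect across the right edge
            window.append(x[j])
        out.append(statistics.median(window))
    return out


print(median_filter_1d([0, 0, 9, 0, 0], size=3))  # [0, 0, 0, 0, 0]
print(median_filter_1d([1, 2, 3, 4, 5], size=3))  # [1, 2, 3, 4, 5]
```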

This transform can be used both as a preprocessing transform and as an online transform.
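The order argument documented in SpectMedianFilter determines where the transform runs when it is combined with other transforms. The following sketch mimics the documented rule (lower order first, ties applied in the order they were added) with Python's stable sort; run_pipeline and the toy numeric "transforms" are hypothetical and not autrainer's pipeline implementation.

```python
from typing import Callable, List, Tuple

Transform = Tuple[int, str, Callable[[float], float]]  # (order, name, function)


def run_pipeline(transforms: List[Transform], x: float) -> Tuple[float, List[str]]:
    """Apply transforms sorted by order; ties keep insertion order
    because sorted() is stable."""
    applied = []
    for _order, name, fn in sorted(transforms, key=lambda t: t[0]):
        x = fn(x)
        applied.append(name)
    return x, applied


pipeline = [
    (0, "double", lambda v: v * 2),
    (-1, "shift", lambda v: v + 1),
    (0, "square", lambda v: v ** 2),
]
print(run_pipeline(pipeline, 3.0))  # (64.0, ['shift', 'double', 'square'])
```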

Custom Preprocessing Transforms

To create a custom preprocessing transform, create a new file in the conf/preprocessing/ directory.

For example, the following preprocessing transform extracts log-Mel spectrograms from audio data at a sampling rate of 32 kHz and applies the custom denoising transform to the data:

conf/preprocessing/denoised_log_mel_32k.yaml
file_handler:
  autrainer.datasets.utils.AudioFileHandler:
    target_sample_rate: 32000
pipeline:
  - autrainer.transforms.StereoToMono
  - autrainer.transforms.PannMel:
      sample_rate: 32000
      window_size: 1024
      hop_size: 320
      mel_bins: 64
      fmin: 50
      fmax: 14000
      ref: 1.0
      amin: 1e-10
      top_db: null
  - spect_median_filter.SpectMedianFilter:
      size: 5

Any audio dataset can utilize this preprocessing transform by specifying the features_subdir attribute in the dataset configuration and adjusting the file_type, file_handler, and transform attributes:

conf/dataset/ExampleDataset.yaml
...
features_subdir: denoised_log_mel_32k
file_type: npy
file_handler: autrainer.datasets.utils.NumpyFileHandler
...
transform:
  type: grayscale

Note

The save() method of the file_handler specified in the dataset configuration is used to save the processed data to the features_subdir directory. The load() method of the file_handler is used to load the processed data during training and inference.

Custom Online Transforms

Custom online transforms do not require a configuration file, as they are applied at runtime and specified using shorthand syntax in the transform attribute of the model or dataset configuration.

For example, the following configuration applies the custom denoising transform to the data at runtime:

conf/dataset/ExampleDataset.yaml
...
transform:
  type: grayscale
  base:
    - spect_median_filter.SpectMedianFilter:
        size: 5

As the transform is listed under base, the custom denoising transform is applied to the train, dev, and test datasets, in line with the custom preprocessing transform example.

It may be desirable to only apply a transform to a specific subset of the data. The following configuration applies the custom denoising transform only to the train subset of the data:

conf/dataset/ExampleDataset.yaml
...
transform:
  type: grayscale
  train:
    - spect_median_filter.SpectMedianFilter:
        size: 5

Custom Augmentations

To create a custom augmentation, inherit from AbstractAugmentation and implement the apply() method.

For example, the following augmentation scales the amplitude of a spectrogram by a random factor in a given range:

amplitude_scale_augmentation.py
from typing import Optional, Tuple

import torch

from autrainer.augmentations import AbstractAugmentation


class AmplitudeScale(AbstractAugmentation):
    def __init__(
        self,
        scale_range: Tuple[float, float],
        order: int = 0,
        p: float = 1.0,
        generator_seed: Optional[int] = None,
    ) -> None:
        """Amplitude scaling augmentation. The amplitude is randomly scaled by
        a factor drawn from scale_range.

        Args:
            scale_range: The range of the amplitude scaling factor.
            order: The order of the augmentation in the transformation pipeline.
                Defaults to 0.
            p: The probability of applying the augmentation. Defaults to 1.0.
            generator_seed: The initial seed for the internal random number
                generator drawing the probability. If None, the generator is
                not seeded. Defaults to None.

        Raises:
            ValueError: If p is not in the range [0, 1].
        """
        super().__init__(order, p, generator_seed)
        self.scale_range = scale_range
        self.scale_g = torch.Generator()
        if self.generator_seed is not None:  # also honor a seed of 0
            self.scale_g.manual_seed(self.generator_seed)

    def apply(self, x: torch.Tensor, index: Optional[int] = None) -> torch.Tensor:
        s0, s1 = self.scale_range
        # draw the factor as a Python float so it works for tensors on any device
        scale = torch.rand(1, generator=self.scale_g).item() * (s1 - s0) + s0
        return x * scale

The following configuration creates a new AmplitudeScale.yaml augmentation in the conf/augmentation/ directory, scaling the amplitude of the spectrogram by a random factor between 0.8 and 1.2 with a probability p of 0.5:

conf/augmentation/AmplitudeScale.yaml
id: AmplitudeScale
_target_: autrainer.augmentations.AugmentationPipeline

generator_seed: 0

pipeline:
  - amplitude_scale_augmentation.AmplitudeScale:
      scale_range: [0.8, 1.2]
      p: 0.5

As no augmentation in the pipeline specifies its own generator_seed attribute, the generator_seed of the pipeline is broadcast to all augmentations to ensure reproducibility.
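The configuration sets generator_seed: 0, and a seed of 0 is falsy in Python, so a check such as `if generator_seed:` would silently leave a generator unseeded; comparing against None avoids this. The following torch-free sketch (AmplitudeScaleSketch is a hypothetical name, and random.Random stands in for torch.Generator) shows that two instances seeded with 0 draw identical scale factors from scale_range.

```python
import random
from typing import Optional, Tuple


class AmplitudeScaleSketch:
    """Torch-free sketch of drawing a reproducible amplitude scale factor."""

    def __init__(
        self, scale_range: Tuple[float, float], generator_seed: Optional[int] = None
    ) -> None:
        self.scale_range = scale_range
        self.rng = random.Random()
        # compare against None so that a seed of 0 is honored
        if generator_seed is not None:
            self.rng.seed(generator_seed)

    def factor(self) -> float:
        s0, s1 = self.scale_range
        return self.rng.random() * (s1 - s0) + s0  # uniform in [s0, s1)


a = AmplitudeScaleSketch((0.8, 1.2), generator_seed=0)
b = AmplitudeScaleSketch((0.8, 1.2), generator_seed=0)
print([round(a.factor(), 3) for _ in range(3)])
print([round(b.factor(), 3) for _ in range(3)])  # identical to the line above
```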

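Custom Augmentation Graphs

Custom augmentations can be combined with other augmentations into more complex augmentation graphs, e.g., using the Choice and Sequential augmentations. For example, the following configuration creates a new AmplitudeScaleOrTimeFreqMask.yaml augmentation in the conf/augmentation/ directory, either applying the custom amplitude scale augmentation or a sequence of the TimeMask and FrequencyMask augmentations: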

conf/augmentation/AmplitudeScaleOrTimeFreqMask.yaml
id: AmplitudeScaleOrTimeFreqMask
_target_: autrainer.augmentations.AugmentationPipeline

generator_seed: 0

pipeline:
  - autrainer.augmentations.Choice:
      weights: [0.2, 0.8]
      choices:
        - amplitude_scale_augmentation.AmplitudeScale:
            scale_range: [0.8, 1.2]
        - autrainer.augmentations.Sequential:
            sequence:
              - autrainer.augmentations.TimeMask:
                  time_mask: 80
              - autrainer.augmentations.FrequencyMask:
                  freq_mask: 10

The custom amplitude scale augmentation is selected with a probability of 0.2, while the sequence of the TimeMask and FrequencyMask augmentations is selected with a probability of 0.8.
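The weights act as selection probabilities for each input. The following stdlib sketch draws 10,000 selections with random.Random.choices to show the resulting frequencies; it only illustrates the sampling idea and does not reproduce the random stream of autrainer.augmentations.Choice.

```python
import random
from collections import Counter

rng = random.Random(0)  # seeded, analogous to generator_seed: 0
choices = ["AmplitudeScale", "TimeMask + FrequencyMask"]
draws = rng.choices(choices, weights=[0.2, 0.8], k=10_000)
counts = Counter(draws)
for name in choices:
    print(name, counts[name] / len(draws))  # close to 0.2 and 0.8
```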

Custom Collate Augmentations

To create a custom collate augmentation, inherit from AbstractAugmentation and implement the optional get_collate_fn() method.

The collate function is used to apply the augmentation on the batch level. If the collate function modifies the shape of the inputs or labels, the same modification may also be required when the augmentation is not applied (e.g., one-hot encoding the labels, as done in the example below).

For example, the following augmentation randomly applies CutMix or MixUp augmentations on the batch level:

cut_mix_up.py
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import torch
from torch.utils.data import default_collate
from torchvision.transforms import v2

from autrainer.augmentations import AbstractAugmentation


if TYPE_CHECKING:
    from autrainer.datasets import AbstractDataset


class CutMixUp(AbstractAugmentation):
    def __init__(
        self,
        alpha: float = 1.0,
        order: int = 0,
        p: float = 1.0,
        generator_seed: Optional[int] = None,
    ) -> None:
        """Randomly applies CutMix or MixUp augmentations with a probability
        of 0.5 each.

        Args:
            alpha: Hyperparameter of the Beta distribution. Defaults to 1.0.
            order: The order of the augmentation in the transformation pipeline.
                Defaults to 0.
            p: The probability of applying the augmentation. Defaults to 1.0.
            generator_seed: The initial seed for the internal random number
                generator drawing the probability. If None, the generator is
                not seeded. Defaults to None.
        """
        super().__init__(order, p, generator_seed)
        self.alpha = alpha
        self.cut_mix_up_g = torch.Generator()
        if generator_seed is not None:  # also honor a seed of 0
            self.cut_mix_up_g.manual_seed(generator_seed)

    def get_collate_fn(self, data: "AbstractDataset") -> Callable:
        self.cutmix = v2.CutMix(num_classes=data.output_dim, alpha=self.alpha)
        self.mixup = v2.MixUp(num_classes=data.output_dim, alpha=self.alpha)

        def _collate_fn(
            batch: List[Tuple[torch.Tensor, int, int]],
        ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
            probability = torch.rand(1, generator=self.g).item()
            if probability < self.p:
                p = torch.rand(1, generator=self.cut_mix_up_g).item()
                if p < 0.5:
                    return self.cutmix(*default_collate(batch))
                return self.mixup(*default_collate(batch))

            # still one-hot encode the labels if no augmentation is applied
            batched = default_collate(batch)
            batched[1] = torch.nn.functional.one_hot(
                batched[1], data.output_dim
            ).float()
            return batched

        return _collate_fn

    def apply(self, x: torch.Tensor, index: Optional[int] = None) -> torch.Tensor:
        # no-op as the augmentation is applied in the collate function
        return x
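The one-hot encoding in the fallback branch matters because MixUp and CutMix produce soft labels: each mixed label is a convex combination of two one-hot vectors, weighted by a factor lambda drawn from a Beta(alpha, alpha) distribution (CutMix then adjusts lambda to the area of the pasted patch). If unaugmented batches kept integer labels, the criterion would receive targets of different shapes depending on whether the augmentation was applied. The following pure-Python sketch illustrates MixUp-style label mixing; one_hot and mix_labels are hypothetical helpers.

```python
import random
from typing import List


def one_hot(label: int, num_classes: int) -> List[float]:
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def mix_labels(y_a: List[float], y_b: List[float], lam: float) -> List[float]:
    """MixUp-style label mixing: lam * y_a + (1 - lam) * y_b."""
    return [lam * a + (1.0 - lam) * b for a, b in zip(y_a, y_b)]


rng = random.Random(0)
lam = rng.betavariate(1.0, 1.0)  # alpha = 1.0, the default of CutMixUp
mixed = mix_labels(one_hot(0, 4), one_hot(2, 4), lam)
print(round(lam, 3), [round(v, 3) for v in mixed])  # soft label summing to 1
```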

Custom Loggers

To create a custom logger, inherit from AbstractLogger and implement the log_params(), log_metrics(), log_timers(), and log_artifact() methods, as well as the optional setup() and end_run() methods.

All methods are automatically called at the appropriate time during training and inference.

For example, the following logger logs to Weights & Biases:

wandb_logger.py
import os
from typing import Dict, List, Union

from omegaconf import DictConfig
import wandb

from autrainer.core.constants import ExportConstants
from autrainer.loggers import (
    AbstractLogger,
    get_params_to_export,
)
from autrainer.metrics import AbstractMetric


class WandBLogger(AbstractLogger):
    def __init__(
        self,
        exp_name: str,
        run_name: str,
        metrics: List[AbstractMetric],
        tracking_metric: AbstractMetric,
        artifacts: List[
            Union[str, Dict[str, str]]
        ] = ExportConstants().ARTIFACTS,
        output_dir: str = "wandb",
    ) -> None:
        super().__init__(
            exp_name, run_name, metrics, tracking_metric, artifacts
        )
        if not os.path.isabs(output_dir):
            output_dir = os.path.join(os.getcwd(), output_dir)
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def log_params(self, params: Union[dict, DictConfig]) -> None:
        wandb.init(
            project=self.exp_name,
            name=self.run_name,
            config=get_params_to_export(params),
            dir=self.output_dir,
        )

    def log_metrics(
        self,
        metrics: Dict[str, Union[int, float]],
        iteration=None,
    ) -> None:
        wandb.log(metrics, step=iteration)

    def log_timers(self, timers: Dict[str, float]) -> None:
        wandb.log(timers)

    def log_artifact(self, filename: str, path: str = "") -> None:
        artifact = wandb.Artifact(name=filename, type="model")
        artifact.add_file(os.path.join(path, filename))
        wandb.log_artifact(artifact)

    def end_run(self) -> None:
        wandb.finish()

Note that the WandBLogger assumes that wandb is installed, the API key is set, and a project with the same name as the experiment_id of the main configuration exists.

To add the WandBLogger, specify it in the main configuration by adding a list of loggers:

conf/config.yaml
...
loggers:
  - wandb_logger.WandBLogger:
      output_dir: ${results_dir}/.wandb
...
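A logger does not need an external service; the same methods can write to a local file. The following dependency-free sketch appends one JSON object per call. JsonLinesLogger is a hypothetical name, it does not inherit from AbstractLogger only so that it runs without autrainer installed, and params is assumed to be a plain dict here, whereas the WandBLogger also accepts a DictConfig.

```python
import json
import os
import tempfile
from typing import Dict, Optional, Union


class JsonLinesLogger:
    """Standalone sketch mirroring the logger methods with a local file.

    In an autrainer project, this class would inherit from AbstractLogger
    and pass the required constructor arguments to super().__init__().
    """

    def __init__(self, output_dir: str) -> None:
        os.makedirs(output_dir, exist_ok=True)
        self.path = os.path.join(output_dir, "log.jsonl")

    def _write(self, kind: str, payload: dict) -> None:
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps({"kind": kind, **payload}) + "\n")

    def log_params(self, params: dict) -> None:
        self._write("params", {"params": params})

    def log_metrics(
        self, metrics: Dict[str, Union[int, float]], iteration: Optional[int] = None
    ) -> None:
        self._write("metrics", {"iteration": iteration, "metrics": metrics})

    def log_timers(self, timers: Dict[str, float]) -> None:
        self._write("timers", {"timers": timers})

    def log_artifact(self, filename: str, path: str = "") -> None:
        self._write("artifact", {"file": os.path.join(path, filename)})

    def end_run(self) -> None:
        self._write("end", {})


logger = JsonLinesLogger(tempfile.mkdtemp())
logger.log_params({"learning_rate": 0.01})
logger.log_metrics({"dev_loss": 0.42}, iteration=1)
logger.end_run()
with open(logger.path, encoding="utf-8") as f:
    print(f.read())
```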

Custom Callbacks

To create a custom callback, implement a class that specifies any of the callback functions defined in CallbackSignature.

For example, the following callback tracks learning rate changes at the beginning of each iteration:

lr_tracker_callback.py
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from autrainer.training import ModularTaskTrainer


class LRTrackerCallback:
    def cb_on_train_begin(self, trainer: "ModularTaskTrainer") -> None:
        self.lr = trainer.optimizer.param_groups[0]["lr"]

    def cb_on_iteration_begin(
        self,
        trainer: "ModularTaskTrainer",
        iteration: int,
    ) -> None:
        current_lr = trainer.optimizer.param_groups[0]["lr"]
        if current_lr != self.lr:
            print(
                f"Learning rate changed from {self.lr} "
                f"to {current_lr} in iteration {iteration}."
            )
            self.lr = current_lr

To add the LRTrackerCallback, specify it in the main configuration by adding a list of callbacks:

conf/config.yaml
...
callbacks:
  - lr_tracker_callback.LRTrackerCallback
...
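Conceptually, only the callback functions a class actually defines are invoked. The following sketch approximates this with getattr-based dispatch and a stand-in trainer built from SimpleNamespace; dispatch and the stand-in trainer are hypothetical and not autrainer's implementation, and the tracker is a copy of LRTrackerCallback that collects messages instead of printing them.

```python
from types import SimpleNamespace
from typing import Any, List


class LRTrackerCallback:
    """Same logic as lr_tracker_callback.py, with messages collected."""

    def __init__(self) -> None:
        self.messages: List[str] = []

    def cb_on_train_begin(self, trainer: Any) -> None:
        self.lr = trainer.optimizer.param_groups[0]["lr"]

    def cb_on_iteration_begin(self, trainer: Any, iteration: int) -> None:
        current_lr = trainer.optimizer.param_groups[0]["lr"]
        if current_lr != self.lr:
            self.messages.append(f"{self.lr} -> {current_lr} in iteration {iteration}")
            self.lr = current_lr


def dispatch(callbacks: List[Any], name: str, *args: Any) -> None:
    """Call ``name`` on every callback that defines it (hypothetical helper)."""
    for cb in callbacks:
        fn = getattr(cb, name, None)
        if callable(fn):
            fn(*args)


trainer = SimpleNamespace(optimizer=SimpleNamespace(param_groups=[{"lr": 0.1}]))
tracker = LRTrackerCallback()
callbacks = [tracker, object()]  # the plain object defines no callback functions
dispatch(callbacks, "cb_on_train_begin", trainer)
for iteration in range(1, 4):
    if iteration == 3:
        trainer.optimizer.param_groups[0]["lr"] = 0.01  # e.g. a scheduler step
    dispatch(callbacks, "cb_on_iteration_begin", trainer, iteration)
print(tracker.messages)  # ['0.1 -> 0.01 in iteration 3']
```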

Custom Plotting

To create a custom plotting configuration, create a new file in the conf/plotting/ directory.

For example, the following configuration uses LaTeX for text rendering and the Palatino font with a legend font size of 9, replaces None values in the run name with ~ for better readability, and adds labels as well as titles to the plots.

conf/plotting/LaTeX.yaml
figsize: [10, 5] # figure size in inches
latex: true # use LaTeX for text rendering
filetypes: [png, pdf] # save figures in these formats
pickle: true # save the figure data in a pickle file
context: notebook # seaborn context
palette: colorblind # seaborn color palette
replace_none: true # replace None with ~
add_titles: true
add_xlabels: true
add_ylabels: true

rcParams:
  font.serif: Palatino # LaTeX font
  font.family: serif
  legend.fontsize: 9

To add the LaTeX.yaml plotting configuration, specify it in the main configuration by overriding the plotting attribute:

conf/config.yaml
defaults:
  - ...
  - override plotting: LaTeX
...