Tutorials#
autrainer is designed to be flexible and extensible, allowing for the creation of custom …
datasets (including metrics, criterions, file handlers, and target transforms)
transforms (including preprocessing transforms and online transforms)
For each, a tutorial is provided below to demonstrate their implementation and configuration.
For the following tutorials, all Python files should be placed in the project root directory
and all configuration files should be placed in the corresponding subdirectories of the conf/
directory.
Custom Models#
To create a custom model, inherit from AbstractModel and implement the forward() and embeddings() methods.
Every constructor argument has to be assigned to an instance attribute with the same name, as AbstractModel
inherits from audobject.
For example, the following model is a simple CNN that takes a spectrogram as input and has a variable number of hidden CNN layers with a different number of filters each:
from typing import List

import torch

from autrainer.models import AbstractModel


class SpectrogramCNN(AbstractModel):
    def __init__(self, output_dim: int, hidden_dims: List[int]) -> None:
        """Spectrogram CNN model with a variable number of hidden CNN layers.

        Args:
            output_dim: Output dimension of the model.
            hidden_dims: List of hidden dimensions for the CNN layers.
        """
        super().__init__(output_dim)
        self.hidden_dims = hidden_dims
        layers = []
        input_dim = 1
        for hidden_dim in self.hidden_dims:
            layers.extend(
                [
                    torch.nn.Conv2d(input_dim, hidden_dim, (3, 3), 1),
                    torch.nn.ReLU(),
                    torch.nn.MaxPool2d((2, 2)),
                ]
            )
            input_dim = hidden_dim
        layers.extend(
            [
                torch.nn.AdaptiveAvgPool2d((1, 1)),
                torch.nn.Flatten(),
            ]
        )
        self.backbone = torch.nn.Sequential(*layers)
        self.classifier = torch.nn.Linear(self.hidden_dims[-1], output_dim)

    def embeddings(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(self.embeddings(x))
Next, create a SpectrogramCNN.yaml
configuration file for the model in the conf/model/
directory:
id: SpectrogramCNN
_target_: spectrogram_cnn.SpectrogramCNN

hidden_dims: [32, 64, 128]

transform:
  type: grayscale
The id
should match the name of the configuration file.
The _target_
should point to the custom model class via a python import path
(here assuming that the spectrogram_cnn.py
file is in the root directory of the project).
Each model should include a transform/type
attribute in the configuration file,
specifying the input type it expects.
Note
The output_dim
attribute is automatically passed to the model during initialization
and determined by the dataset at runtime.
The transform
attribute in the configuration is not passed to the model during initialization
and is used to specify the input type of the model and any
online transforms to be applied to the data at runtime.
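Since the convolutions in SpectrogramCNN use a 3x3 kernel with stride 1 and no padding, and each is followed by 2x2 max pooling, every hidden layer shrinks the spectrogram. The following sketch works through that arithmetic to check whether an input is large enough for a given hidden_dims. The helper name backbone_spatial_size and the input size of 64 x 101 are assumptions made for illustration and are not part of autrainer:

```python
from typing import List, Tuple


def backbone_spatial_size(
    height: int, width: int, hidden_dims: List[int]
) -> Tuple[int, int]:
    """Spatial size of the feature map that reaches AdaptiveAvgPool2d."""
    for _ in hidden_dims:
        # Conv2d with a 3x3 kernel, stride 1, and no padding removes
        # 2 pixels along each axis.
        height, width = height - 2, width - 2
        if height < 2 or width < 2:
            raise ValueError("input too small for this number of layers")
        # MaxPool2d((2, 2)) halves each axis, rounding down.
        height, width = height // 2, width // 2
    return height, width


# Hypothetical spectrogram with 64 mel bins and 101 frames.
size = backbone_spatial_size(64, 101, [32, 64, 128])
print(size)
```

The adaptive average pooling at the end of the backbone then reduces whatever spatial size remains to 1 x 1, which is why the classifier only depends on the last entry of hidden_dims.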
Custom Datasets#
To create a custom dataset, inherit from AbstractDataset
and implement the target_transform
and output_dim
properties.
The train, dev, and test datasets as well as loaders are automatically created by the abstract class.
However, this requires that the dataset structure follows the standard format outlined in the dataset documentation.
If the dataset structure is different or does not rely on dataframes, the load_dataframes(), train_dataset, train_loader, etc. methods and properties can be overridden.
autrainer provides base datasets for classification (BaseClassificationDataset), regression (BaseRegressionDataset), and multi-label classification (BaseMLClassificationDataset) tasks.
When inheriting from one of these base datasets, both the target transform and output dimension are already implemented and do not need to be overridden.
Tip
To automatically download a custom dataset, implement the download()
method.
This method is called by the autrainer fetch CLI command as well as the fetch()
CLI wrapper function.
The path
attribute specified in the dataset configuration file is passed to the method to store the downloaded data in.
ESC-50 Example
For example, the ESC-50 dataset is an audio classification dataset and can be implemented as follows:
import os
import shutil
from typing import Any, Dict, List, Tuple

import pandas as pd

from autrainer.datasets import BaseClassificationDataset
from autrainer.datasets.utils import ZipDownloadManager


FILES = {"ESC-50.zip": "https://github.com/karoldvl/ESC-50/archive/master.zip"}


class ESC50(BaseClassificationDataset):
    def __init__(
        self,
        train_folds: List[int],
        dev_folds: List[int],
        test_folds: List[int],
        **kwargs: Dict[str, Any],  # kwargs only for simplicity in the tutorial
    ) -> None:
        self.train_folds = train_folds
        self.dev_folds = dev_folds
        self.test_folds = test_folds
        super().__init__(**kwargs)

    def load_dataframes(
        self,
    ) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        meta = pd.read_csv(os.path.join(self.path, "esc50.csv"))
        return (
            meta[meta["fold"].isin(self.train_folds)],
            meta[meta["fold"].isin(self.dev_folds)],
            meta[meta["fold"].isin(self.test_folds)],
        )

    @staticmethod
    def download(path: str) -> None:
        if os.path.exists(os.path.join(path, "default")):
            return

        dl_manager = ZipDownloadManager(FILES, path)
        dl_manager.download(check_exist=["ESC-50.zip"])
        dl_manager.extract(check_exist=["ESC-50-master"])
        shutil.move(
            os.path.join(path, "ESC-50-master", "audio"),
            os.path.join(path, "default"),
        )
        shutil.move(
            os.path.join(path, "ESC-50-master", "meta", "esc50.csv"),
            path,
        )
        shutil.rmtree(os.path.join(path, "ESC-50-master"))
The dataset provides audio files by default (which are moved to the default/
directory in the
download()
method) and the corresponding metadata of the dataset is stored in the esc50.csv
file.
To allow the specification of custom folds, the load_dataframes()
method is overridden to split the esc50.csv
file into the respective train, dev, and test dataframes.
This also allows for cross-validation by creating multiple configurations with different folds.
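As a rough sketch of the cross-validation idea, the fold assignments for one configuration per held-out test fold could be generated as follows. ESC-50 ships with five folds; the helper name cross_validation_folds and the choice of using the fold preceding the test fold as the dev fold are assumptions made for illustration:

```python
from typing import Dict, List


def cross_validation_folds(num_folds: int = 5) -> List[Dict[str, List[int]]]:
    """Create one train/dev/test fold assignment per held-out test fold."""
    folds = list(range(1, num_folds + 1))
    configs = []
    for test_fold in folds:
        # Use the fold before the test fold (wrapping around) for validation.
        dev_fold = folds[test_fold - 2]
        train_folds = [f for f in folds if f not in (dev_fold, test_fold)]
        configs.append(
            {
                "train_folds": train_folds,
                "dev_folds": [dev_fold],
                "test_folds": [test_fold],
            }
        )
    return configs


configs = cross_validation_folds()
for config in configs:
    print(config)
```

Each resulting dictionary corresponds to the train_folds, dev_folds, and test_folds attributes of one dataset configuration file.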
To extract log-Mel spectrograms from the audio files, a preprocessing transform
can be applied to the data before training.
The following configuration creates a new ESC50-32k.yaml
dataset in the conf/dataset/
directory with log-Mel spectrograms
preprocessed at a sample rate of 32 kHz:
id: ESC50-32k
_target_: esc_50.ESC50

path: data/ESC50
features_subdir: log_mel_32k
index_column: filename
target_column: category
file_type: npy
file_handler: autrainer.datasets.utils.NumpyFileHandler

train_folds: [1, 2, 3]
dev_folds: [4]
test_folds: [5]

criterion: autrainer.criterions.BalancedCrossEntropyLoss
metrics:
  - autrainer.metrics.Accuracy
  - autrainer.metrics.UAR
  - autrainer.metrics.F1
tracking_metric: autrainer.metrics.Accuracy

transform:
  type: grayscale
The dataset can be automatically downloaded and preprocessed using the autrainer fetch
and autrainer preprocess CLI commands
or the fetch()
and preprocess()
CLI wrapper functions.
Simple Dataset Example
If the structure of the dataset follows the standard format outlined in the dataset documentation,
no implementation is necessary and a new dataset can be created by simply adding a configuration file to the conf/dataset/
directory.
For example, the following configuration file creates a new SpectrogramDataset-32k.yaml
classification dataset,
preprocessing the data with a spectrogram preprocessing transform at a sample rate of 32 kHz:
id: SpectrogramDataset-32k
_target_: autrainer.datasets.BaseClassificationDataset

path: data/SpectrogramDataset # base path to the dataset
features_subdir: log_mel_32k # spectrogram preprocessed features
index_column: path # column in the CSVs containing features paths relative to features_subdir
target_column: label # column in the CSVs containing the target labels
file_type: npy # file extension of the spectrogram features
file_handler: autrainer.datasets.utils.NumpyFileHandler # file handler for the spectrogram features

criterion: autrainer.criterions.BalancedCrossEntropyLoss
metrics:
  - autrainer.metrics.Accuracy
  - autrainer.metrics.UAR
  - autrainer.metrics.F1
tracking_metric: autrainer.metrics.Accuracy

transform:
  type: grayscale
This dataset assumes that the data/SpectrogramDataset directory contains the following directories and files:
default/ directory containing the raw audio files. These audio files are preprocessed using the spectrogram preprocessing transform with the autrainer preprocess CLI command or the preprocess() CLI wrapper function and stored in the data/SpectrogramDataset/log_mel_32k directory.
train.csv, dev.csv, and test.csv files containing the file paths relative to the default/ directory in the index_column column and the corresponding labels in the target_column column.
Custom Metrics#
To create a custom metric, inherit from AbstractMetric and implement the starting_metric and suffix properties,
as well as the get_best(), get_best_pos(), and compare() static methods.
autrainer provides base classes for ascending (BaseAscendingMetric) and descending (BaseDescendingMetric) metrics
that can be inherited from to simplify the implementation.
For example, the following metric implements the Cohen’s Kappa score with either linear or quadratic weights:
import sklearn.metrics

from autrainer.metrics import BaseAscendingMetric


class CohensKappa(BaseAscendingMetric):
    def __init__(self, weights: str) -> None:
        """Cohen's Kappa metric using `sklearn.metrics.cohen_kappa_score`.

        Args:
            weights: Weighting type for the metric in ["linear", "quadratic"].
        """
        super().__init__(
            name="cohens-kappa",
            fn=sklearn.metrics.cohen_kappa_score,
            weights=weights,
        )
The fn
attribute is the function that is automatically called in the __call__()
method
and the weights
attribute is passed to the fn
as a keyword argument.
As metrics are specified using shorthand syntax in the dataset configuration,
the following relative import path can be used to reference it as the tracking_metric
for the dataset:
...
tracking_metric:
  cohens_kappa_metric.CohensKappa:
    weights: linear # linear or quadratic
...
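To illustrate what the weights argument changes, the following standalone sketch computes the weighted Cohen's Kappa by hand, penalizing a disagreement between classes i and j with |i - j| (linear) or (i - j)^2 (quadratic). It is an illustrative reimplementation that assumes integer labels 0..k-1 and at least two distinct labels; the function name weighted_kappa and the example labels are made up, and in practice sklearn.metrics.cohen_kappa_score should be used as in the metric above:

```python
from typing import List


def weighted_kappa(y_true: List[int], y_pred: List[int], weights: str) -> float:
    """Weighted Cohen's kappa for integer labels 0..k-1."""
    if weights not in ("linear", "quadratic"):
        raise ValueError("weights must be 'linear' or 'quadratic'")
    k = max(max(y_true), max(y_pred)) + 1
    n = len(y_true)
    # Observed confusion matrix.
    observed = [[0.0] * k for _ in range(k)]
    for t, p in zip(y_true, y_pred):
        observed[t][p] += 1
    row = [sum(r) for r in observed]
    col = [sum(observed[i][j] for i in range(k)) for j in range(k)]
    power = 1 if weights == "linear" else 2
    weighted_observed = 0.0
    weighted_expected = 0.0
    for i in range(k):
        for j in range(k):
            w = abs(i - j) ** power
            weighted_observed += w * observed[i][j]
            # Expected count under chance agreement.
            weighted_expected += w * row[i] * col[j] / n
    return 1.0 - weighted_observed / weighted_expected


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 0, 1, 1, 2, 0]  # one prediction is two classes off
kappa_linear = weighted_kappa(y_true, y_pred, "linear")
kappa_quadratic = weighted_kappa(y_true, y_pred, "quadratic")
print(kappa_linear, kappa_quadratic)
```

In this example the single error is two classes away from the target, so the quadratic weighting penalizes it more strongly than the linear weighting.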
Custom Criterions#
To create a custom criterion, inherit from torch.nn.modules.loss._Loss
and implement the forward()
method.
If the criterion relies on the dataset, an optional criterion setup method
can be defined which is called after the dataset is initialized.
For example, the following criterion implements CrossEntropyLoss with an additional scaling factor:
import torch


class ScaledCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    def __init__(self, scaling_factor: float = 1.0, *args, **kwargs):
        """Cross entropy loss with a scaling factor.

        Args:
            scaling_factor: Scaling factor for the loss.
            *args: Positional arguments passed to `torch.nn.CrossEntropyLoss`.
            **kwargs: Keyword arguments passed to `torch.nn.CrossEntropyLoss`.
        """
        super().__init__(*args, **kwargs)
        self.scaling_factor = scaling_factor

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        if y.ndim == 1:
            y = y.long()
        return self.scaling_factor * super().forward(x, y)
As criterions are specified using shorthand syntax in the dataset configuration,
the following relative import path can be used to reference it as the criterion
for the dataset:
...
criterion:
  scaled_ce_loss.ScaledCrossEntropyLoss:
    scaling_factor: 0.5
...
Or without overriding the default scaling_factor
value:
...
criterion: scaled_ce_loss.ScaledCrossEntropyLoss
...
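The effect of the scaling factor can be worked through for a single sample without PyTorch. The sketch below is a plain Python illustration of the arithmetic, not the criterion itself; the function names and the example logits are assumptions:

```python
import math
from typing import List


def cross_entropy(logits: List[float], target: int) -> float:
    """Cross-entropy of one sample: -log(softmax(logits)[target])."""
    # Subtract the maximum logit for numerical stability.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def scaled_cross_entropy(
    logits: List[float], target: int, scaling_factor: float = 1.0
) -> float:
    """Cross-entropy multiplied by a constant scaling factor."""
    return scaling_factor * cross_entropy(logits, target)


# Two equally likely classes: the unscaled loss is log(2).
unscaled = scaled_cross_entropy([0.0, 0.0], target=0)
scaled = scaled_cross_entropy([0.0, 0.0], target=0, scaling_factor=0.5)
print(unscaled, scaled)
```

Because the factor is a constant multiplier, it scales the gradients of the loss by the same amount.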
Custom File Handlers#
To create a custom file handler, inherit from AbstractFileHandler
and implement the load()
and save()
methods.
For example, the following file handler loads and saves PyTorch tensors:
import torch

from autrainer.datasets.utils import AbstractFileHandler


class TorchFileHandler(AbstractFileHandler):
    def load(self, file: str) -> torch.Tensor:
        return torch.load(file)

    def save(self, file: str, data: torch.Tensor) -> None:
        torch.save(data, file)
File handlers are specified using shorthand syntax in the dataset configuration.
The following configuration utilizes the TorchFileHandler to load and save PyTorch tensors with the file extension .pt:
...
file_type: pt
file_handler: torch_file_handler.TorchFileHandler
...
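The load() and save() pair is expected to round-trip the data. The following sketch checks this for a hypothetical JSON-based handler. So that it runs without autrainer installed, the class is a standalone stand-in that does not inherit from AbstractFileHandler; in a project it would, and file_type would be set to json:

```python
import json
import os
import tempfile
from typing import Any


class JsonFileHandler:
    """Standalone stand-in mirroring the load()/save() interface."""

    def load(self, file: str) -> Any:
        with open(file, "r", encoding="utf-8") as f:
            return json.load(f)

    def save(self, file: str, data: Any) -> None:
        with open(file, "w", encoding="utf-8") as f:
            json.dump(data, f)


handler = JsonFileHandler()
features = [[0.1, 0.2], [0.3, 0.4]]
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "features.json")
    handler.save(path, features)
    restored = handler.load(path)
print(restored == features)
```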
Custom Target Transforms#
To create a custom target transform, inherit from AbstractTargetTransform and implement the encode(), decode(), predict_batch(), and majority_vote() methods.
For example, the following target transform logarithmically encodes and decodes the targets for regression tasks:
import math
from typing import Dict, List, Union

import torch

from autrainer.datasets.utils import AbstractTargetTransform


class LogTargetTransform(AbstractTargetTransform):
    def __init__(self, target: str, base: int = 10, eps: float = 1e-9) -> None:
        """Logarithmic target transform for regression tasks.

        Args:
            target: Name of the target.
            base: Base of the logarithm. Defaults to 10.
            eps: Small value to avoid taking the logarithm of zero.
                Defaults to 1e-9.
        """
        self.target = target
        self.base = base
        self.eps = eps

    def encode(self, x: float) -> float:
        return math.log(x + self.eps, self.base)

    def decode(self, x: float) -> float:
        return math.pow(self.base, x) - self.eps

    def probabilities_inference(self, x: torch.Tensor) -> torch.Tensor:
        return x

    def predict_inference(self, x: torch.Tensor) -> Union[List[float], float]:
        return x.squeeze().tolist()

    def majority_vote(self, x: List[float]) -> float:
        return sum(x) / len(x)

    def probabilities_to_dict(self, x: torch.Tensor) -> Dict[str, float]:
        return {self.target: x.item()}
The target transforms are specified in the target_transform
property
of a dataset implementation.
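The encode() and decode() methods of the LogTargetTransform above are inverses of each other up to floating-point error. A minimal standalone check, with the two methods rewritten as plain functions and an arbitrary example target of 250.0, could look like this:

```python
import math


def encode(x: float, base: int = 10, eps: float = 1e-9) -> float:
    """Logarithmically encode a target, offset by eps to allow x = 0."""
    return math.log(x + eps, base)


def decode(x: float, base: int = 10, eps: float = 1e-9) -> float:
    """Invert encode() by exponentiating and removing the eps offset."""
    return math.pow(base, x) - eps


round_trip = decode(encode(250.0))
zero_encoded = encode(0.0)  # finite because of eps, roughly log10(1e-9)
print(round_trip, zero_encoded)
```

Without eps, a target of exactly zero would raise a ValueError in math.log.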
Custom Optimizers#
To create a custom optimizer, inherit from
torch.optim.Optimizer and implement the step()
method.
For example, the following optimizer implements the SGD optimizer with an additional randomly scaled learning rate using a custom step function:
from typing import Callable, Optional, Tuple

import torch


# Subclass SGD so that step() and the handling of params and lr are inherited.
class RandomScaledSGD(torch.optim.SGD):
    def __init__(
        self,
        scaling_factor: float = 0.01,
        p: float = 1.0,
        generator_seed: Optional[int] = None,
        *args,
        **kwargs,
    ) -> None:
        """Randomized Scaled SGD optimizer. Randomly scales the learning rate.

        Args:
            scaling_factor: Learning rate scaling factor. Defaults to 0.01.
            p: Probability of scaling the learning rate. Defaults to 1.0.
            generator_seed: Seed for the random number generator.
                Defaults to None.
        """
        super().__init__(*args, **kwargs)
        self.scaling_factor = scaling_factor
        self.p = p
        self.g = torch.Generator()
        self.base_lr = self.param_groups[0]["lr"]
        # Compare against None so that a seed of 0 is not ignored.
        if generator_seed is not None:
            self.g.manual_seed(generator_seed)

    def custom_step(
        self,
        model: torch.nn.Module,  # model
        data: torch.Tensor,  # batched input data
        target: torch.Tensor,  # batched target data
        criterion: torch.nn.Module,  # loss function
        probabilities_fn: Callable,  # function to get probabilities from model outputs
    ) -> Tuple[float, torch.Tensor]:
        self.zero_grad()
        output = model(data)
        loss = criterion(probabilities_fn(output), target)
        loss.backward()
        # Scale the learning rate for this step only, with probability p.
        if torch.rand(1, generator=self.g).item() < self.p:
            self.param_groups[0]["lr"] *= self.scaling_factor
        self.step()
        self.param_groups[0]["lr"] = self.base_lr
        return loss.item(), output.detach()
The following configuration creates a new RandomScaledSGD.yaml
optimizer in the conf/optimizer/
directory and uses the global
seed of the main configuration as the generator_seed
attribute:
id: RandomScaledSGD
_target_: random_scaled_sgd.RandomScaledSGD

scaling_factor: 0.001
p: 0.05
generator_seed: ${seed}
Note
The params
and lr
attributes are automatically passed to the optimizer during initialization
and determined at runtime.
Custom Schedulers#
To create a custom scheduler, inherit from
torch.optim.lr_scheduler.LRScheduler
and implement the get_lr()
method.
For example, the following scheduler implements a simple linear warm-up scheduler:
from typing import List

import torch
from torch.optim.lr_scheduler import LRScheduler


class LinearWarmUpLR(LRScheduler):
    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        warmup_steps: int,
        last_epoch: int = -1,
    ) -> None:
        """Linear warm-up learning rate scheduler.

        Args:
            optimizer: Wrapped optimizer.
            warmup_steps: Number of warmup steps.
            last_epoch: The index of last epoch. Defaults to -1.
        """
        self.warmup_steps = warmup_steps
        super().__init__(optimizer, last_epoch)

    def get_lr(self) -> List[float]:
        if self.last_epoch < self.warmup_steps:
            return [
                base_lr * (self.last_epoch + 1) / self.warmup_steps
                for base_lr in self.base_lrs
            ]
        return self.base_lrs
The following configuration creates a new LinearWarmUpLR.yaml
scheduler with a linear warm-up period of 10
training iterations in the conf/scheduler/
directory:
id: LinearWarmUpLR
_target_: linear_warm_up_lr.LinearWarmUpLR

warmup_steps: 10

step_frequency: evaluation
Note
The optimizer
attribute is automatically passed to the scheduler during initialization and determined at runtime.
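The learning rates produced by the get_lr() formula of the LinearWarmUpLR scheduler above can be tabulated without PyTorch. The sketch below assumes a base learning rate of 0.1 and a shortened warm-up of 4 steps for readability; the helper name linear_warmup_lrs is made up for illustration:

```python
from typing import List


def linear_warmup_lrs(
    base_lr: float, warmup_steps: int, num_steps: int
) -> List[float]:
    """Learning rate per scheduler step, with last_epoch counted from 0."""
    lrs = []
    for last_epoch in range(num_steps):
        if last_epoch < warmup_steps:
            # Ramp up linearly and reach base_lr on the final warm-up step.
            lrs.append(base_lr * (last_epoch + 1) / warmup_steps)
        else:
            lrs.append(base_lr)
    return lrs


schedule = linear_warmup_lrs(base_lr=0.1, warmup_steps=4, num_steps=6)
print([round(lr, 4) for lr in schedule])
```

The learning rate rises in equal increments during the warm-up and stays at the base learning rate afterwards.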
Custom Transforms#
To create a custom transform, inherit from AbstractTransform
and implement the __call__()
method.
For example, the following transform denoises a spectrogram by applying a median filter:
import scipy.ndimage
import torch

from autrainer.transforms import AbstractTransform


class SpectMedianFilter(AbstractTransform):
    def __init__(self, size: int, order: int = 0) -> None:
        """Spectrogram median filter to remove noise.

        Args:
            size: Number of neighboring pixels to consider when filtering.
                Must be odd.
            order: The order of the transform in the pipeline. Larger means
                later in the pipeline. If multiple transforms have the same
                order, they are applied in the order they were added to the
                pipeline. Defaults to 0.
        """
        super().__init__(order=order)
        self.size = size

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return torch.from_numpy(
            scipy.ndimage.median_filter(
                x.cpu().numpy(),
                size=self.size,
            )
        ).to(x.device)
This transform can be used both as a preprocessing transform and as an online transform.
Custom Preprocessing Transforms#
To create a custom preprocessing transform, create a new file in the conf/preprocessing/
directory.
For example, the following preprocessing transform extracts log-Mel spectrograms from audio data at a sampling rate of 32 kHz and applies the custom denoising transform to the data:
file_handler:
  autrainer.datasets.utils.AudioFileHandler:
    target_sample_rate: 32000
pipeline:
  - autrainer.transforms.StereoToMono
  - autrainer.transforms.PannMel:
      sample_rate: 32000
      window_size: 1024
      hop_size: 320
      mel_bins: 64
      fmin: 50
      fmax: 14000
      ref: 1.0
      amin: 1e-10
      top_db: null
  - spect_median_filter.SpectMedianFilter:
      size: 5
Any audio dataset can utilize this preprocessing transform by specifying the features_subdir
attribute in the dataset configuration and adjusting the file_type
, file_handler
, and transform
attributes:
...
features_subdir: denoised_log_mel_32k
file_type: npy
file_handler: autrainer.datasets.utils.NumpyFileHandler
...
transform:
  type: grayscale
Custom Online Transforms#
To create a custom online transform, no configuration file is necessary as the transform is applied at runtime
and specified in the transform
attribute of the model and dataset configurations using shorthand syntax.
For example, the following configuration applies the custom denoising transform to the data at runtime:
...
transform:
  type: grayscale
  base:
    - spect_median_filter.SpectMedianFilter:
        size: 5
In line with the custom preprocessing transform example, the custom denoising transform is applied to the train, dev, and test datasets.
It may be desirable to only apply a transform to a specific subset of the data.
The following configuration applies the custom denoising transform only to the train
subset of the data:
...
transform:
  type: grayscale
  train:
    - spect_median_filter.SpectMedianFilter:
        size: 5
Custom Augmentations#
To create a custom augmentation, inherit from AbstractAugmentation
and implement the apply()
method.
For example, the following augmentation scales the amplitude of a spectrogram by a random factor in a given range:
from typing import Optional, Tuple

import torch

from autrainer.augmentations import AbstractAugmentation


class AmplitudeScale(AbstractAugmentation):
    def __init__(
        self,
        scale_range: Tuple[float, float],
        order: int = 0,
        p: float = 1.0,
        generator_seed: Optional[int] = None,
    ) -> None:
        """Amplitude scaling augmentation. The amplitude is randomly scaled by
        a factor drawn from scale_range.

        Args:
            scale_range: The range of the amplitude scaling factor.
            order: The order of the augmentation in the transformation pipeline.
                Defaults to 0.
            p: The probability of applying the augmentation. Defaults to 1.0.
            generator_seed: The initial seed for the internal random number
                generator drawing the probability. If None, the generator is
                not seeded. Defaults to None.

        Raises:
            ValueError: If p is not in the range [0, 1].
        """
        super().__init__(order, p, generator_seed)
        self.scale_range = scale_range
        self.scale_g = torch.Generator()
        # Compare against None so that a seed of 0 is not ignored.
        if self.generator_seed is not None:
            self.scale_g.manual_seed(self.generator_seed)

    def apply(self, x: torch.Tensor, index: Optional[int] = None) -> torch.Tensor:
        s0, s1 = self.scale_range
        # Map a uniform sample from [0, 1) to the range [s0, s1).
        return x * (torch.rand(1, generator=self.scale_g) * (s1 - s0) + s0)
The following configuration creates a new AmplitudeScale.yaml
augmentation in the conf/augmentation/
directory,
scaling the amplitude of the spectrogram by a random factor between 0.8 and 1.2 with a probability p
of 0.5:
id: AmplitudeScale
_target_: autrainer.augmentations.AugmentationPipeline

generator_seed: 0

pipeline:
  - amplitude_scale_augmentation.AmplitudeScale:
      scale_range: [0.8, 1.2]
      p: 0.5
As no augmentation in the pipeline specifies a generator_seed attribute,
the global generator_seed attribute is broadcast to all augmentations to ensure reproducibility.
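The role of the seed can be illustrated with a small standalone sketch that draws scaling factors the same way as apply(), mapping a uniform sample u to u * (s1 - s0) + s0. It uses random.Random from the standard library as a stand-in for the seeded torch.Generator, so the drawn values differ from those of the actual augmentation; the helper name draw_scale_factors is an assumption:

```python
import random
from typing import List, Tuple


def draw_scale_factors(
    scale_range: Tuple[float, float], seed: int, n: int
) -> List[float]:
    """Draw n scaling factors from scale_range with a seeded generator."""
    s0, s1 = scale_range
    rng = random.Random(seed)  # stand-in for the seeded torch.Generator
    # Map u in [0, 1) to the scaling range: u * (s1 - s0) + s0.
    return [rng.random() * (s1 - s0) + s0 for _ in range(n)]


first = draw_scale_factors((0.8, 1.2), seed=0, n=5)
second = draw_scale_factors((0.8, 1.2), seed=0, n=5)
print(first == second)
```

Two generators seeded with the same value yield the same sequence of factors, which is what makes augmented runs repeatable.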
Custom Augmentation Graphs#
For example, the following configuration creates a new AmplitudeScaleOrTimeFreqMask.yaml
augmentation in the
conf/augmentation/
directory, either applying the custom amplitude scale augmentation or a sequence of the
TimeMask
and FrequencyMask
augmentations:
id: AmplitudeScaleOrTimeFreqMask
_target_: autrainer.augmentations.AugmentationPipeline

generator_seed: 0

pipeline:
  - autrainer.augmentations.Choice:
      weights: [0.2, 0.8]
      choices:
        - amplitude_scale_augmentation.AmplitudeScale:
            scale_range: [0.8, 1.2]
        - autrainer.augmentations.Sequential:
            sequence:
              - autrainer.augmentations.TimeMask:
                  time_mask: 80
              - autrainer.augmentations.FrequencyMask:
                  freq_mask: 10
The custom amplitude scale augmentation is selected with a probability of 0.2,
while the sequence of the TimeMask
and FrequencyMask
augmentations is selected with a probability of 0.8.
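The weighted selection between the two branches can be simulated with the standard library to get a feeling for how often each branch is applied. This sketch only mimics the selection probabilities with random.choices; it is not how autrainer implements Choice, and the branch labels are placeholders:

```python
import random
from collections import Counter

rng = random.Random(0)
branches = ["AmplitudeScale", "TimeMask+FrequencyMask"]

# Draw one branch per sample with the weights from the configuration.
draws = rng.choices(branches, weights=[0.2, 0.8], k=10_000)
counts = Counter(draws)
share = counts["AmplitudeScale"] / len(draws)
print(counts, share)
```

Over many samples, the share of the amplitude scale branch approaches its weight of 0.2.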
Custom Collate Augmentations#
To create a custom collate augmentation, inherit from AbstractAugmentation
and implement the optional get_collate_fn()
method.
The collate function is used to apply the augmentation on the batch level. In case the collate function modifies the shape of the input or labels, this may need to be accounted for if the augmentation is not applied.
For example, the following augmentation randomly applies CutMix
or MixUp
augmentations
on the batch level:
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import torch
from torch.utils.data import default_collate
from torchvision.transforms import v2

from autrainer.augmentations import AbstractAugmentation


if TYPE_CHECKING:
    from autrainer.datasets import AbstractDataset


class CutMixUp(AbstractAugmentation):
    def __init__(
        self,
        alpha: float = 1.0,
        order: int = 0,
        p: float = 1.0,
        generator_seed: Optional[int] = None,
    ) -> None:
        """Randomly applies CutMix or MixUp augmentations with a probability
        of 0.5 each.

        Args:
            alpha: Hyperparameter of the Beta distribution. Defaults to 1.0.
            order: The order of the augmentation in the transformation pipeline.
                Defaults to 0.
            p: The probability of applying the augmentation. Defaults to 1.0.
            generator_seed: The initial seed for the internal random number
                generator drawing the probability. If None, the generator is
                not seeded. Defaults to None.
        """
        super().__init__(order, p, generator_seed)
        self.alpha = alpha
        self.cut_mix_up_g = torch.Generator()
        # Compare against None so that a seed of 0 is not ignored.
        if generator_seed is not None:
            self.cut_mix_up_g.manual_seed(generator_seed)

    def get_collate_fn(self, data: "AbstractDataset") -> Callable:
        self.cutmix = v2.CutMix(num_classes=data.output_dim, alpha=self.alpha)
        self.mixup = v2.MixUp(num_classes=data.output_dim, alpha=self.alpha)

        def _collate_fn(
            batch: List[Tuple[torch.Tensor, int, int]],
        ) -> List[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
            probability = torch.rand(1, generator=self.g).item()
            if probability < self.p:
                p = torch.rand(1, generator=self.cut_mix_up_g).item()
                if p < 0.5:
                    return self.cutmix(*default_collate(batch))
                return self.mixup(*default_collate(batch))

            # still one-hot encode the labels if no augmentation is applied
            batched = default_collate(batch)
            batched[1] = torch.nn.functional.one_hot(
                batched[1], data.output_dim
            ).float()
            return batched

        return _collate_fn

    def apply(self, x: torch.Tensor, index: Optional[int] = None) -> torch.Tensor:
        # no-op as the augmentation is applied in the collate function
        return x
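The reason the labels are one-hot encoded even when no augmentation is applied becomes clearer when the mixing is written out. The following plain Python sketch shows the MixUp idea for two samples with a fixed mixing coefficient of 0.7 (MixUp draws it from a Beta distribution parameterized by alpha); the helper names and values are assumptions for illustration:

```python
from typing import List


def one_hot(label: int, num_classes: int) -> List[float]:
    """Encode an integer label as a one-hot vector."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]


def mix(a: List[float], b: List[float], lam: float) -> List[float]:
    """Convex combination lam * a + (1 - lam) * b, element by element."""
    return [lam * x + (1.0 - lam) * y for x, y in zip(a, b)]


lam = 0.7  # fixed here for readability
mixed_input = mix([1.0, 2.0], [3.0, 4.0], lam)
mixed_target = mix(one_hot(0, 3), one_hot(2, 3), lam)
print(mixed_input, mixed_target)
```

As mixed targets are vectors over all classes rather than integer labels, unaugmented batches are converted to the same one-hot shape so that the criterion always receives targets of a consistent form.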
Custom Loggers#
To create a custom logger, inherit from AbstractLogger and implement the
log_params(), log_metrics(), log_timers(), and log_artifact() methods,
as well as the optional setup() and end_run() methods.
All methods are automatically called at the appropriate time during training and inference.
For example, the following logger logs to Weights & Biases:
import os
from typing import Dict, List, Union

from omegaconf import DictConfig
import wandb

from autrainer.core.constants import ExportConstants
from autrainer.loggers import (
    AbstractLogger,
    get_params_to_export,
)
from autrainer.metrics import AbstractMetric


class WandBLogger(AbstractLogger):
    def __init__(
        self,
        exp_name: str,
        run_name: str,
        metrics: List[AbstractMetric],
        tracking_metric: AbstractMetric,
        artifacts: List[
            Union[str, Dict[str, str]]
        ] = ExportConstants().ARTIFACTS,
        output_dir: str = "wandb",
    ) -> None:
        super().__init__(
            exp_name, run_name, metrics, tracking_metric, artifacts
        )
        if not os.path.isabs(output_dir):
            output_dir = os.path.join(os.getcwd(), output_dir)
        os.makedirs(output_dir, exist_ok=True)
        self.output_dir = output_dir

    def log_params(self, params: Union[dict, DictConfig]) -> None:
        wandb.init(
            project=self.exp_name,
            name=self.run_name,
            config=get_params_to_export(params),
            dir=self.output_dir,
        )

    def log_metrics(
        self,
        metrics: Dict[str, Union[int, float]],
        iteration=None,
    ) -> None:
        wandb.log(metrics, step=iteration)

    def log_timers(self, timers: Dict[str, float]) -> None:
        wandb.log(timers)

    def log_artifact(self, filename: str, path: str = "") -> None:
        artifact = wandb.Artifact(name=filename, type="model")
        artifact.add_file(os.path.join(path, filename))
        wandb.log_artifact(artifact)

    def end_run(self) -> None:
        wandb.finish()
Note that the WandBLogger
assumes that wandb is installed, the API key is set,
and a project with the same name as the experiment_id
of the main configuration exists.
To add the WandBLogger, specify it in the main configuration by adding a list of loggers:
...
loggers:
  - wandb_logger.WandBLogger:
      output_dir: ${results_dir}/.wandb
...
Custom Callbacks#
To create a custom callback, implement a class that specifies any of the callback functions defined in CallbackSignature.
For example, the following callback tracks learning rate changes at the beginning of each iteration:
from typing import TYPE_CHECKING


if TYPE_CHECKING:
    from autrainer.training import ModularTaskTrainer


class LRTrackerCallback:
    def cb_on_train_begin(self, trainer: "ModularTaskTrainer") -> None:
        self.lr = trainer.optimizer.param_groups[0]["lr"]

    def cb_on_iteration_begin(
        self,
        trainer: "ModularTaskTrainer",
        iteration: int,
    ) -> None:
        current_lr = trainer.optimizer.param_groups[0]["lr"]
        if current_lr != self.lr:
            print(
                f"Learning rate changed from {self.lr} "
                f"to {current_lr} in iteration {iteration}."
            )
            self.lr = current_lr
To add the LRTrackerCallback, specify it in the main configuration by adding a list of callbacks:
...
callbacks:
  - lr_tracker_callback.LRTrackerCallback
...
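Because a callback only needs to define the hooks it is interested in, a trainer can look up each hook by name and skip callbacks that do not define it. The sketch below illustrates this pattern with a hypothetical dispatch helper and a placeholder trainer; it is not autrainer's trainer implementation, and the IterationRecorder class is made up for the example:

```python
from typing import Any, List


class IterationRecorder:
    """Toy callback that only defines the iteration-begin hook."""

    def __init__(self) -> None:
        self.seen: List[int] = []

    def cb_on_iteration_begin(self, trainer: Any, iteration: int) -> None:
        self.seen.append(iteration)


def dispatch(callbacks: List[Any], hook: str, **kwargs: Any) -> None:
    """Call `hook` on every callback that defines it (hypothetical helper)."""
    for callback in callbacks:
        fn = getattr(callback, hook, None)
        if callable(fn):
            fn(**kwargs)


recorder = IterationRecorder()
# The recorder does not define cb_on_train_begin, so it is skipped here.
dispatch([recorder], "cb_on_train_begin", trainer=None)
for iteration in range(1, 4):
    dispatch(
        [recorder], "cb_on_iteration_begin", trainer=None, iteration=iteration
    )
print(recorder.seen)
```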
Custom Plotting#
To create a custom plotting configuration, create a new file in the conf/plotting/
directory.
For example, the following configuration uses the LaTeX backend,
the Palatino font with a font size of 9, replaces None
values in the run name with ~ for better readability,
and adds labels as well as titles to the plot:
figsize: [10, 5] # figure size in inches
latex: true # use LaTeX for text rendering
filetypes: [png, pdf] # save figures in these formats
pickle: true # save the figure data in a pickle file
context: notebook # seaborn context
palette: colorblind # seaborn color palette
replace_none: true # replace None with ~
add_titles: true
add_xlabels: true
add_ylabels: true

rcParams:
  font.serif: Palatino # LaTeX font
  font.family: serif
  legend.fontsize: 9
To add the LaTeX.yaml
plotting configuration, specify it in the main configuration
by overriding the plotting
attribute:
defaults:
  - ...
  - override plotting: LaTeX
...