CLI Wrapper

autrainer provides an autrainer.cli CLI wrapper to programmatically manage the entire training process, including configuration management, data preprocessing, model training, inference, and postprocessing.

Wrapper functions are useful for integrating autrainer into custom scripts, Jupyter notebooks, Google Colab notebooks, and other applications.

In addition to the CLI wrapper functions, autrainer provides a CLI to manage configurations, data, training, inference, and postprocessing from the command line with the same functionality as the CLI wrapper.

Configuration Management

To manage configurations, create() sets up the project structure, while list() and show() allow for discovering and saving the default configurations provided by autrainer.

Tip

Default configurations can be discovered through the CLI, the CLI wrapper, and the respective module documentation.

autrainer.cli.create(directories=None, empty=False, all=False, force=False)

Create a new project with default configurations.

If called in a notebook, the function does not raise an error but prints the error message instead.

Parameters:
  • directories (Optional[List[str]]) – Configuration directories to create. One or more of: CONFIG_DIRS. Defaults to None.

  • empty (bool) – Create an empty project without any configuration directory. Defaults to False.

  • all (bool) – Create a project with all configuration directories. Defaults to False.

  • force (bool) – Force overwrite if the configuration directory already exists. Defaults to False.

Raises:
  • CommandLineError – If no configuration directories are specified and neither the empty nor all flags are set.

  • CommandLineError – If the empty and all flags are set at the same time.

  • CommandLineError – If the empty or all flags are set in combination with configuration directories.

  • CommandLineError – If the configuration directory already exists and the force flag is not set.

Return type:

None
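
The argument rules listed under Raises can be summarised as a small check. The sketch below is illustrative only: check_create_args is a hypothetical helper that mirrors the documented conditions (apart from the existing-directory check, which depends on the file system), and the directory names are placeholders. It is not autrainer's own implementation.

```python
from typing import List, Optional


def check_create_args(
    directories: Optional[List[str]] = None,
    empty: bool = False,
    all: bool = False,  # shadows the builtin, mirroring the documented signature
) -> Optional[str]:
    """Return a message for the first violated rule, or None if the call is valid.

    Hypothetical helper that mirrors the documented create() argument rules.
    """
    if empty and all:
        return "empty and all cannot be set at the same time"
    if directories and (empty or all):
        return "empty or all cannot be combined with directories"
    if not directories and not empty and not all:
        return "specify directories, empty, or all"
    return None


# Placeholder directory names, for illustration only.
print(check_create_args(directories=["model", "dataset"]))
print(check_create_args(empty=True, all=True))
```

A combination for which the helper returns a message corresponds to one of the documented CommandLineError cases.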

autrainer.cli.list(directory, local_only=False, global_only=False, pattern='*')

List local and global configurations.

If called in a notebook, the function does not raise an error but prints the error message instead.

Parameters:
  • directory (str) – The directory to list configurations from. Choose from: CONFIG_DIRS.

  • local_only (bool) – List local configurations only. Defaults to False.

  • global_only (bool) – List global configurations only. Defaults to False.

  • pattern (str) – Glob pattern to filter configurations. Defaults to “*”.

Raises:

CommandLineError – If the local configuration directory does not exist and local_only is True.

Return type:

None
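
The pattern argument is a glob. As a rough illustration of glob filtering, the sketch below applies Python's fnmatch module to made-up configuration names. The names and the filter_names helper are assumptions, and whether autrainer matches names with or without the .yaml extension is not stated here.

```python
from fnmatch import fnmatchcase
from typing import List

# Hypothetical configuration names, for illustration only.
NAMES = ["ResNet18", "ResNet50", "VGG16", "ToyModel"]


def filter_names(names: List[str], pattern: str = "*") -> List[str]:
    """Keep the names that match a case-sensitive glob pattern."""
    return [name for name in names if fnmatchcase(name, pattern)]


print(filter_names(NAMES, "ResNet*"))
print(filter_names(NAMES))  # the default "*" keeps every name
```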

autrainer.cli.show(directory, config, save=False, force=False)

Show and save a global configuration.

If called in a notebook, the function does not raise an error but prints the error message instead.

Parameters:
  • directory (str) – The directory to list configurations from. Choose from: CONFIG_DIRS.

  • config (str) – The global configuration to show. Configurations can be discovered using the ‘autrainer list’ command.

  • save (bool) – Save the global configuration to the local conf directory. Defaults to False.

  • force (bool) – Force overwrite local configuration if it exists in combination with save=True. Defaults to False.

Raises:
  • CommandLineError – If the global configuration does not exist.

  • CommandLineError – If while saving the local configuration, the configuration already exists and force is not set.

Return type:

None

Preprocessing

To avoid race conditions when using Launcher Plugins that may run multiple training jobs in parallel, fetch() and preprocess() allow for downloading and preprocessing of Datasets (and pretrained model states) before training.

Both commands are based on the main configuration file (e.g. conf/config.yaml), such that the specified models and datasets are fetched and preprocessed accordingly. If a model or dataset is already fetched or preprocessed, it will be skipped.

autrainer.cli.fetch(override_kwargs=None, cfg_launcher=False, config_name='config', config_path=None)

Fetch the datasets and models specified in a training configuration.

Parameters:
  • override_kwargs (Optional[dict]) – Additional Hydra override arguments to pass to the train script.

  • cfg_launcher (bool) – Use the launcher specified in the configuration instead of the Hydra basic launcher. Defaults to False.

  • config_name (str) – The name of the config (usually the file name without the .yaml extension). Defaults to “config”.

  • config_path (Optional[str]) – The config path, a directory where Hydra will search for config files. If config_path is None no directory is added to the search path. Defaults to None.

Return type:

None
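
override_kwargs takes a dict of Hydra overrides. For orientation, Hydra's command-line override grammar has the form key=value, and the sketch below renders a dict in that form. The helper name and the example keys are hypothetical, and how autrainer converts the dict internally is not described here.

```python
from typing import Any, Dict, List


def to_hydra_overrides(override_kwargs: Dict[str, Any]) -> List[str]:
    """Render a dict as Hydra-style key=value override strings."""
    return [f"{key}={value}" for key, value in override_kwargs.items()]


# Hypothetical keys, for illustration only.
print(to_hydra_overrides({"seed": 1, "iterations": 5}))
```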

autrainer.cli.preprocess(override_kwargs=None, num_workers=1, update_frequency=1, cfg_launcher=False, config_name='config', config_path=None)

Launch a data preprocessing configuration.

Parameters:
  • override_kwargs (Optional[dict]) – Additional Hydra override arguments to pass to the train script.

  • num_workers (int) – Number of workers (threads) to use for preprocessing. Defaults to 1.

  • update_frequency (int) – Frequency of progress bar updates for each worker (thread). If 0, the progress bar will be disabled. Defaults to 1.

  • cfg_launcher (bool) – Use the launcher specified in the configuration instead of the Hydra basic launcher. Defaults to False.

  • config_name (str) – The name of the config (usually the file name without the .yaml extension). Defaults to “config”.

  • config_path (Optional[str]) – The config path, a directory where Hydra will search for config files. If config_path is None no directory is added to the search path. Defaults to None.

Return type:

None

Training

Training is managed by train(), which starts the training process based on the main configuration file (e.g. conf/config.yaml).

autrainer.cli.train(override_kwargs=None, config_name='config', config_path=None)

Launch a training configuration.

Parameters:
  • override_kwargs (Optional[dict]) – Additional Hydra override arguments to pass to the train script.

  • config_name (str) – The name of the config (usually the file name without the .yaml extension). Defaults to “config”.

  • config_path (Optional[str]) – The config path, a directory where Hydra will search for config files. If config_path is None no directory is added to the search path. Defaults to None.

Return type:

None
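
Fetching and preprocessing before training, as described in the Preprocessing section, can be chained in a script. The sketch below is one possible arrangement: run_pipeline is a hypothetical helper that receives the wrapper module as an argument and uses only the documented keyword arguments. Whether all three steps can run back to back in a single process in every setup is an assumption.

```python
from typing import Any, Dict, Optional


def run_pipeline(
    cli: Any,
    override_kwargs: Optional[Dict[str, Any]] = None,
    num_workers: int = 1,
) -> None:
    """Fetch, preprocess, then train with the same overrides.

    `cli` is expected to be the autrainer.cli module, or any object that
    exposes fetch, preprocess, and train with the documented signatures.
    """
    # Download datasets and pretrained model states first ...
    cli.fetch(override_kwargs=override_kwargs)
    # ... then preprocess them once, before any parallel jobs start ...
    cli.preprocess(override_kwargs=override_kwargs, num_workers=num_workers)
    # ... and finally launch training on the prepared data.
    cli.train(override_kwargs=override_kwargs)
```

With autrainer installed, the helper would be called as run_pipeline(autrainer.cli) after import autrainer.cli.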

Inference

inference() allows for the (sliding window) inference of audio data using a trained model.

Both local paths and Hugging Face Hub links are supported for the model. Models referenced by Hugging Face Hub links are automatically downloaded and cached in the torch hub cache directory.

The following syntax is supported for Hugging Face Hub links: hf:repo_id[@revision][:subdir][#local_dir]. This syntax consists of the following components:

  • hf: The Hugging Face Hub prefix indicating that the model is fetched from the Hugging Face Hub.

  • repo_id: The repository ID of the model consisting of the user name and the model card name separated by a slash (e.g. autrainer/example).

  • revision (optional): The revision as a commit hash, branch name, or tag name (e.g. main). If not specified, the latest revision is used.

  • subdir (optional): The subdirectory of the repository containing the model directory (e.g. AudioModel). If not specified, the model directory is automatically inferred. If multiple models are present in the repo_id, subdir must be specified, as the correct model cannot be automatically inferred.

  • local_dir (optional): The local directory to which the model is downloaded (e.g. .hf_local). If not specified, the model is placed in the torch hub cache directory.

For example, to download the model from the repository autrainer/example at the revision main from the subdirectory AudioModel and save it to the local directory .hf_local, inference() can be called as follows:

import autrainer.cli

autrainer.cli.inference(
    model="hf:autrainer/example@main:AudioModel#.hf_local",
    input="input",
    output="output",
    device="cuda:0",
)
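
To make the link components concrete, the sketch below splits a link into the parts listed above, treating revision, subdir, and local_dir as optional. The regular expression and the parse_hf_link helper are an illustrative reading of the documented syntax, not autrainer's own parser.

```python
import re
from typing import Dict, Optional

# Illustrative pattern for hf:repo_id[@revision][:subdir][#local_dir].
_HF_LINK = re.compile(
    r"^hf:(?P<repo_id>[^@:#]+)"      # user name / model card name
    r"(?:@(?P<revision>[^:#]+))?"    # optional commit hash, branch, or tag
    r"(?::(?P<subdir>[^#]+))?"       # optional subdirectory in the repository
    r"(?:#(?P<local_dir>.+))?$"      # optional local download directory
)


def parse_hf_link(link: str) -> Dict[str, Optional[str]]:
    """Split a Hugging Face Hub link into its components (None if absent)."""
    match = _HF_LINK.match(link)
    if match is None:
        raise ValueError(f"Not a valid hf: link: {link!r}")
    return match.groupdict()


print(parse_hf_link("hf:autrainer/example@main:AudioModel#.hf_local"))
```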

Tip

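To access private repositories, the environment variable HF_HOME should point to the Hugging Face home directory that holds the User Access Token (by default, huggingface_hub stores the token in a file named token inside that directory).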

To use a custom endpoint (e.g. for a self-hosted hub), the environment variable HF_ENDPOINT should point to the desired endpoint URL.

To use a local model path, inference() can be called as follows:

import autrainer.cli

autrainer.cli.inference(
    model="/path/to/AudioModel",
    input="input",
    output="output",
    device="cuda:0",
)

autrainer.cli.inference(model, input, output, checkpoint='_best', device='cpu', extension='wav', recursive=False, embeddings=False, update_frequency=1, preprocess_cfg='default', window_length=None, stride_length=None, min_length=None, sample_rate=None)

Perform inference on a trained model.

If called in a notebook, the function does not raise an error but prints the error message instead.

Parameters:
  • model (str) – Local path to the model directory or a Hugging Face Hub link of the format: hf:repo_id[@revision][:subdir][#local_dir]. The model directory should contain at least one state subdirectory as well as the model.yaml, file_handler.yaml, target_transform.yaml, and inference_transform.yaml files.

  • input (str) – Path to input directory. Should contain audio files of the specified extension.

  • output (str) – Path to output directory. Output includes a YAML file with predictions and a CSV file with model outputs.

  • checkpoint (str) – Checkpoint to use for evaluation. Defaults to ‘_best’.

  • device (str) – Device to use for processing, e.g. ‘cpu’ or a CUDA-enabled device such as ‘cuda:0’. Defaults to ‘cpu’.

  • extension (str) – Type of file to look for in the input directory. Defaults to ‘wav’.

  • recursive (bool) – Recursively search for files in the input directory. Defaults to False.

  • embeddings (bool) – Extract embeddings from the model in addition to predictions. For each file, a .pt file with embeddings will be saved. Defaults to False.

  • update_frequency (int) – Frequency of progress bar updates. If 0, the progress bar will be disabled. Defaults to 1.

  • preprocess_cfg (Optional[str]) – Preprocessing configuration to apply to input. Can be a path to a YAML file or the name of the preprocessing configuration in the local or autrainer ‘conf/preprocessing’ directory. If “default”, the default preprocessing configuration used during training will be applied. If None, no preprocessing will be applied. Defaults to “default”.

  • window_length (Optional[float]) – Window length for sliding window inference in seconds. If None, the entire input will be processed at once. Defaults to None.

  • stride_length (Optional[float]) – Stride length for sliding window inference in seconds. If None, the entire input will be processed at once. Defaults to None.

  • min_length (Optional[float]) – Minimum length of audio file to process in seconds. Files shorter than the minimum length are padded with zeros. Sample rate has to be specified for padding. If None, no minimum length is enforced. Defaults to None.

  • sample_rate (Optional[int]) – Sample rate of audio files in Hz. Has to be specified for sliding window inference. Defaults to None.

Raises:

CommandLineError – If the model, input, or preprocessing configuration does not exist, or if the device is invalid.

Return type:

None
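
The interplay of window_length, stride_length, min_length, and sample_rate can be illustrated with a little arithmetic. The sketch below converts seconds to samples and lists window boundaries. Both helpers are hypothetical, and dropping a trailing partial window is an assumption, since the handling of the final window is not described here.

```python
from typing import List, Tuple


def window_bounds(
    num_samples: int,
    sample_rate: int,
    window_length: float,
    stride_length: float,
) -> List[Tuple[int, int]]:
    """Return (start, end) sample indices of each full sliding window."""
    window = int(round(window_length * sample_rate))
    stride = int(round(stride_length * sample_rate))
    if window <= 0 or stride <= 0:
        raise ValueError("window_length and stride_length must be positive")
    # Assumption: only full windows are kept; a trailing partial one is dropped.
    return [(s, s + window) for s in range(0, num_samples - window + 1, stride)]


def pad_to_min_length(
    samples: List[float], sample_rate: int, min_length: float
) -> List[float]:
    """Right-pad with zeros until the signal lasts at least min_length seconds."""
    target = int(round(min_length * sample_rate))
    return samples + [0.0] * max(0, target - len(samples))


# Three seconds of audio at 16 kHz, 1 s windows, 0.5 s stride.
bounds = window_bounds(48000, 16000, window_length=1.0, stride_length=0.5)
print(len(bounds), bounds[0], bounds[-1])
```

Both helpers need the sample rate to convert seconds into samples, which is consistent with sample_rate being required for sliding window inference and for padding.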

Postprocessing

Postprocessing allows for the summarization, visualization, and aggregation of the training results using postprocess(). Several cleanup utilities are provided by rm_failed() and rm_states(). Manual grouping of the training results can be done using group().

autrainer.cli.postprocess(results_dir, experiment_id, max_runs=None, aggregate=None)

Postprocess grid search results.

If called in a notebook, the function does not raise an error but prints the error message instead.

Parameters:
  • results_dir (str) – Path to grid search results directory.

  • experiment_id (str) – ID of experiment to postprocess.

  • max_runs (Optional[int]) – Maximum number of best runs to plot. Defaults to None.

  • aggregate (Optional[List[List[str]]]) – Configurations to aggregate. One or more of: VALID_AGGREGATIONS. Defaults to None.

Raises:

CommandLineError – If the results directory or experiment ID does not exist.

Return type:

None

autrainer.cli.rm_failed(results_dir, experiment_id, force=False)

Delete failed runs from an experiment.

If called in a notebook, the function does not raise an error but prints the error message instead.

Parameters:
  • results_dir (str) – Path to grid search results directory.

  • experiment_id (str) – ID of experiment to postprocess.

  • force (bool) – Force deletion of failed runs without confirmation. Defaults to False.

Raises:

CommandLineError – If the results directory or experiment ID does not exist.

Return type:

None

autrainer.cli.rm_states(results_dir, experiment_id, keep_best=True, keep_runs=None, keep_iterations=None)

Delete states (.pt files) from an experiment.

If called in a notebook, the function does not raise an error but prints the error message instead.

Parameters:
  • results_dir (str) – Path to grid search results directory.

  • experiment_id (str) – ID of experiment to postprocess.

  • keep_best (bool) – Keep best states. Defaults to True.

  • keep_runs (Optional[List[str]]) – Runs to keep. Defaults to None.

  • keep_iterations (Optional[List[int]]) – Iterations to keep. Defaults to None.

Raises:

CommandLineError – If the results directory or experiment ID does not exist.

Return type:

None
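
The cleanup and summarization functions documented above can be combined after a grid search has finished. The sketch below shows one possible order: clean_and_summarize is a hypothetical helper, and it takes the wrapper module as an argument and uses only the documented parameters.

```python
from typing import Any, Optional


def clean_and_summarize(
    cli: Any,
    results_dir: str,
    experiment_id: str,
    max_runs: Optional[int] = None,
) -> None:
    """Remove failed runs, postprocess the rest, then prune saved states.

    `cli` is expected to be the autrainer.cli module, or any object that
    exposes rm_failed, postprocess, and rm_states as documented.
    """
    # force=True deletes failed runs without asking for confirmation.
    cli.rm_failed(results_dir, experiment_id, force=True)
    # Summarize, visualize, and aggregate the remaining runs.
    cli.postprocess(results_dir, experiment_id, max_runs=max_runs)
    # Delete .pt state files, keeping only the best states.
    cli.rm_states(results_dir, experiment_id, keep_best=True)
```

With autrainer installed, a call could look like clean_and_summarize(autrainer.cli, "results", "default"), where both values are placeholders for an actual results directory and experiment ID.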

autrainer.cli.group(override_kwargs=None, config_name='group', config_path=None)

Launch a manual grouping of multiple grid search results.

Parameters:
  • override_kwargs (Optional[dict]) – Additional Hydra override arguments to pass to the train script.

  • config_name (str) – The name of the config (usually the file name without the .yaml extension). Defaults to “group”.

  • config_path (Optional[str]) – The config path, a directory where Hydra will search for config files. If config_path is None no directory is added to the search path. Defaults to None.

Return type:

None