
Core provides various utilities and entry points for the autrainer framework.

Entry Point

The main training entry point for autrainer.

autrainer.main(config_name, config_path=None, version_base=None)

Hydra main decorator with additional autrainer configs.

The conf directory in the current working directory is always added to the search path if it exists. The current working directory is also added to the Python path.

  • config_name (str) – The name of the config (usually the file name without the .yaml extension).

  • config_path (Optional[str]) – The config path, a directory where Hydra will search for config files. If config_path is None, no directory is added to the search path. Defaults to None.

  • version_base (Optional[str]) – Hydra version base. Defaults to None.


Instantiation

Instantiation functions wrap Hydra object instantiation, adding type safety and Shorthand Syntax support.

autrainer.instantiate(config, instance_of=None, convert=None, recursive=False, **kwargs)

Instantiate an object from a configuration Dict or DictConfig.

The config must contain a _target_ field that specifies a relative import path to the object to instantiate. If _target_ is None, returns None.

  • config (Union[DictConfig, Dict]) – The configuration to instantiate.

  • instance_of (Optional[Type[TypeVar(T)]]) – The expected type of the instantiated object. Defaults to None.

  • convert (Optional[HydraConvertEnum]) – The conversion strategy to use, one of HydraConvertEnum. Convert is only used if the config does not have a _convert_ attribute. If None, uses HydraConvertEnum.ALL. Defaults to None.

  • recursive (bool) – Whether to recursively instantiate objects. Recursive is only used if the config does not have a _recursive_ field. Defaults to False.

  • **kwargs – Additional keyword arguments to pass to the object.

  • ValueError – If the config does not have a _target_ field.

  • ValueError – If the instantiated object is not an instance of instance_of and instance_of is provided.
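The contract described above can be illustrated with a minimal, standard-library-only sketch. This is not autrainer's or Hydra's implementation: instantiate_sketch is a hypothetical helper, it resolves _target_ as an absolute dotted import path through importlib, and it ignores the convert and recursive options entirely.

```python
import importlib
from fractions import Fraction


def instantiate_sketch(config, instance_of=None, **kwargs):
    """Hypothetical stand-in that mirrors the documented contract only."""
    if "_target_" not in config:
        raise ValueError("config has no _target_ field")
    target = config["_target_"]
    if target is None:
        return None
    # Split "package.module.Name" into module path and attribute name.
    module_name, _, attr = target.rpartition(".")
    cls = getattr(importlib.import_module(module_name), attr)
    # Remaining config entries become keyword arguments; explicit kwargs win.
    params = {k: v for k, v in config.items() if k != "_target_"}
    obj = cls(**{**params, **kwargs})
    if instance_of is not None and not isinstance(obj, instance_of):
        raise ValueError(f"{target} is not an instance of {instance_of.__name__}")
    return obj


half = instantiate_sketch(
    {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 2},
    instance_of=Fraction,
)
```

Passing instance_of=str for the same config would raise a ValueError in this sketch, which is the type check the documented function adds on top of plain instantiation.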

autrainer.instantiate_shorthand(config, instance_of=None, convert=None, recursive=False, **kwargs)

Instantiate an object from a shorthand configuration.

A shorthand config is either a string or a dictionary with a single key. If config is a string, it should be a python import path. If config is a dictionary, the key should be a python import path and the value should be a dictionary of keyword arguments.

  • config (Union[str, DictConfig, Dict]) – The config to instantiate.

  • instance_of (Optional[Type[TypeVar(T)]]) – The expected type of the instantiated object. Defaults to None.

  • convert (Optional[HydraConvertEnum]) – The conversion strategy to use, one of HydraConvertEnum. Convert is only used if the config does not have a _convert_ attribute. If None, uses HydraConvertEnum.ALL. Defaults to None.

  • recursive (bool) – Whether to recursively instantiate objects. Recursive is only used if the config does not have a _recursive_ field. Defaults to False.

  • **kwargs – Additional keyword arguments to pass to the object.


ValueError – If the config is empty (None or an empty string/dictionary).
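The two shorthand forms can be sketched as a small normalization step. parse_shorthand is a hypothetical helper written for illustration, not part of autrainer; rejecting dictionaries with more than one key is an assumption based on the "single key" wording above.

```python
def parse_shorthand(config):
    """Hypothetical helper: normalize a shorthand config to (path, kwargs)."""
    if not config:  # None, "" and {} are all treated as empty
        raise ValueError("shorthand config is empty")
    if isinstance(config, str):
        # A bare string is an import path without keyword arguments.
        return config, {}
    if len(config) != 1:
        # Assumption: anything but exactly one key is rejected.
        raise ValueError("shorthand dict must have exactly one key")
    path, kwargs = next(iter(config.items()))
    return path, dict(kwargs or {})


from_string = parse_shorthand("collections.OrderedDict")
from_dict = parse_shorthand({"fractions.Fraction": {"numerator": 3, "denominator": 4}})
```

The resulting pair corresponds to a full configuration with the path as _target_ and the dictionary as its keyword arguments.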

Utils

Utils provide various helpers for I/O, logging, timing, and hardware information.

class autrainer.core.utils.Bookkeeping(output_directory, file_handler_path=None)

Bookkeeping to handle general disk operations and interactions.

  • output_directory (str) – Output directory to save files to.

  • file_handler_path (Optional[str]) – Path to save the log file to. Defaults to None.

log(message, level=20)

Log a message.

  • message (str) – Message to log.

  • level (int) – Logging level. Defaults to logging.INFO.

log_to_file(message, level=20)

Log a message to the file handler.

  • message (str) – Message to log.

  • level (int) – Logging level. Defaults to logging.INFO.
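The default level=20 in both signatures is the numeric value of logging.INFO. The sketch below uses only the standard logging module to show how integer levels behave; the logger name and the ListHandler class are hypothetical and unrelated to how Bookkeeping sets up its handlers.

```python
import logging


class ListHandler(logging.Handler):
    """Hypothetical handler that collects records in memory."""

    def __init__(self):
        super().__init__()
        self.records = []

    def emit(self, record):
        self.records.append((record.levelname, record.getMessage()))


logger = logging.getLogger("bookkeeping_sketch")  # hypothetical logger name
logger.setLevel(logging.INFO)
logger.propagate = False  # keep the example output out of the root logger
handler = ListHandler()
logger.addHandler(handler)

logger.log(20, "epoch finished")          # 20 == logging.INFO
logger.log(logging.DEBUG, "not emitted")  # below the INFO threshold
```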

create_folder(folder_name, path='')

Create a new folder in the output directory.

  • folder_name (str) – Name of the folder to create.

  • path (str) – Subdirectory to create the folder in. Defaults to “”.
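The path argument recurs in most Bookkeeping methods as a subdirectory relative to the output directory. A minimal sketch of that path resolution follows, assuming the pieces are simply joined; create_folder_sketch is a hypothetical function and the exist_ok behaviour is an assumption.

```python
import os
import tempfile


def create_folder_sketch(output_directory, folder_name, path=""):
    """Hypothetical helper: create <output_directory>/<path>/<folder_name>."""
    # With the default path="", os.path.join skips the empty component.
    target = os.path.join(output_directory, path, folder_name)
    os.makedirs(target, exist_ok=True)  # assumption: existing folders are fine
    return target


with tempfile.TemporaryDirectory() as out_dir:
    top_level = create_folder_sketch(out_dir, "plots")
    nested = create_folder_sketch(out_dir, "best", path="checkpoints")
    created = sorted(os.listdir(out_dir))
```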

save_model_summary(model, dataset, filename)

Save a model summary to a file.

  • model (Module) – Model to summarize.

  • dataset (AbstractDataset) – Dataset to get the input size from.

  • filename (str) – Name of the file to save the summary to.

save_state(obj, filename, path='')

Save the state of an object.

  • obj (Union[Module, Optimizer, LRScheduler]) – Object to save the state of.

  • filename (str) – Name of the file to save the state to.

  • path (str) – Subdirectory to save the state to. Defaults to “”.


TypeError – If the object type is not supported.

load_state(obj, filename, path='')

Load the state of an object.

  • obj (Union[Module, Optimizer, LRScheduler]) – Object to load the state into.

  • filename (str) – Name of the file to load the state from.

  • path (str) – Subdirectory to load the state from. Defaults to “”.

  • TypeError – If the object type is not supported.

  • FileNotFoundError – If the file is not found.

save_audobject(obj, filename, path='')

Save an audobject.Object to disk.

  • obj (Object) – Object to save.

  • filename (str) – Name of the file to save the object to.

  • path (str) – Subdirectory to save the object to. Defaults to “”.


TypeError – If the object type is not supported.

save_results_dict(results_dict, filename, path='')

Save a results dictionary to disk.

  • results_dict (Dict[str, float]) – Dictionary of metric names and values to save.

  • filename (str) – Name of the file to save the results to.

  • path (str) – Subdirectory to save the results to. Defaults to “”.

save_results_df(results_df, filename, path='')

Save a results DataFrame to disk.

  • results_df (DataFrame) – DataFrame to save.

  • filename (str) – Name of the file to save the results to.

  • path (str) – Subdirectory to save the results to. Defaults to “”.

save_results_np(results_np, filename, path='')

Save a results numpy array to disk.

  • results_np (ndarray) – Numpy array to save.

  • filename (str) – Name of the file to save the results to.

  • path (str) – Subdirectory to save the results to. Defaults to “”.

save_best_results(metrics, filename, metric_fns, tracking_metric_fn, path='')

Save the best results to disk.

  • metrics (DataFrame) – DataFrame of metrics to save.

  • filename (str) – Name of the file to save the best results to.

  • metric_fns (List[AbstractMetric]) – List of metric functions to get the best results from.

  • tracking_metric_fn (AbstractMetric) – Tracking metric function to get the best iteration from.

  • path (str) – Subdirectory to save the best results to. Defaults to “”.

class autrainer.core.utils.Timer(output_directory, timer_type)

Timer to measure time of different parts of the training process.

  • output_directory (str) – Directory to save the timer.yaml file to.

  • timer_type (str) – Name of the timer.


Start the timer.

Stop the timer.


ValueError – If the timer was not started.

Get the time log.

Get the mean time in seconds.

Get the total time in seconds.

classmethod pretty_time(seconds)

Convert seconds to a pretty string.


seconds (float) – Time in seconds.

Time in a pretty string format.


Save and append the timer to timer.yaml.


path (str) – Subdirectory to save the timer.yaml file to relative to the output directory. Defaults to “”.

Dictionary with mean and total time in seconds and pretty format.
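The timing behaviour described above can be sketched with the standard library. TimerSketch is a hypothetical stand-in, not the autrainer class: apart from pretty_time, all method names are assumptions, the HH:MM:SS output format is an assumption, and nothing is written to timer.yaml.

```python
import time


class TimerSketch:
    """Hypothetical stand-in; names and formats are assumptions."""

    def __init__(self):
        self._start = None
        self.time_log = []  # elapsed seconds of every start/stop pair

    def start(self):
        self._start = time.perf_counter()

    def stop(self):
        if self._start is None:
            raise ValueError("Timer was not started.")
        self.time_log.append(time.perf_counter() - self._start)
        self._start = None

    def total_seconds(self):
        return sum(self.time_log)

    def mean_seconds(self):
        return self.total_seconds() / len(self.time_log) if self.time_log else 0.0

    @classmethod
    def pretty_time(cls, seconds):
        # Assumed format: zero-padded hours, minutes, and seconds.
        minutes, secs = divmod(int(seconds), 60)
        hours, minutes = divmod(minutes, 60)
        return f"{hours:02d}:{minutes:02d}:{secs:02d}"


timer = TimerSketch()
for _ in range(3):
    timer.start()
    time.sleep(0.001)
    timer.stop()

# Summary with mean and total time in seconds and in pretty format.
summary = {
    "mean_seconds": timer.mean_seconds(),
    "total_seconds": timer.total_seconds(),
    "mean": TimerSketch.pretty_time(timer.mean_seconds()),
    "total": TimerSketch.pretty_time(timer.total_seconds()),
}
```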


Get hardware information of the current system.


device (device) – Device to get the hardware information from.


autrainer.core.utils.save_hardware_info(output_directory, device)

Save hardware information to a hardware.yaml file.

  • output_directory (str) – Directory to save the hardware information to.

  • device (device) – Device to get the hardware information from.

Set a global seed for reproducibility for random, numpy, and torch.

If CUDA is available, set the seed for CUDA and cuDNN as well.


seed (int) – Seed to set.
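A global seeding routine of this kind might look like the following sketch. set_seed_sketch is a hypothetical function, not the autrainer implementation; NumPy and PyTorch are seeded only if they are installed, and the specific cuDNN flags are an assumption about what "setting the seed for cuDNN" entails.

```python
import random


def set_seed_sketch(seed):
    """Hypothetical helper: seed random, and numpy/torch when available."""
    random.seed(seed)
    try:
        import numpy as np

        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch

        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
            # Assumption: deterministic cuDNN kernels, no autotuning.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
    except ImportError:
        pass


set_seed_sketch(42)
first_draw = random.random()
set_seed_sketch(42)
second_draw = random.random()  # same value as first_draw
```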

Context manager to suppress stdout and stderr.
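One common way to build such a context manager uses contextlib from the standard library. suppress_output is a hypothetical name, and this sketch only redirects Python-level streams; output written directly to the underlying file descriptors (for example by C extensions) is not captured, and the autrainer helper may work differently.

```python
import contextlib
import os


@contextlib.contextmanager
def suppress_output():
    """Hypothetical helper: discard Python-level stdout and stderr."""
    with open(os.devnull, "w") as devnull:
        with contextlib.redirect_stdout(devnull), contextlib.redirect_stderr(devnull):
            yield


with suppress_output():
    print("this line is discarded")
print("this line is visible again")
```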



Plotting

Plotting provides a simple interface to plot metrics of a single run during Training as well as multiple runs during Postprocessing.


To create custom plotting configurations, refer to the custom plotting configurations tutorial.

By default, training plots are saved as png files for each metric. This can optionally be extended to any format supported by Matplotlib and additionally pickled for further processing.


Plots are fully customizable by providing Matplotlib rcParams in a custom plotting configuration.

class autrainer.core.plotting.PlotBase(output_directory, training_type, figsize, latex, filetypes, pickle, context, palette, replace_none, add_titles, add_xlabels, add_ylabels, rcParams)

Base class for plotting.

  • output_directory (str) – Output directory to save plots to.

  • training_type (str) – Type of training in [“Epoch”, “Step”].

  • figsize (tuple) – Figure size in inches.

  • latex (bool) – Whether to use LaTeX in plots. Requires the latex package. To install all necessary dependencies, run: pip install autrainer[latex].

  • filetypes (list) – Filetypes to save plots as.

  • pickle (bool) – Whether to save additional pickle files of the plots.

  • context (str) – Context for seaborn plots.

  • palette (str) – Color palette for seaborn plots.

  • replace_none (bool) – Whether to replace “None” in labels with “~”.

  • add_titles (bool) – Whether to add titles to plots.

  • add_xlabels (bool) – Whether to add x-labels to plots.

  • add_ylabels (bool) – Whether to add y-labels to plots.

  • rcParams (dict) – Additional Matplotlib rcParams to set.

save_plot(fig, name, path='', close=True, tight_layout=True)

Save a plot to the output directory.

  • fig (Figure) – Matplotlib figure to save.

  • name (str) – Name of the plot.

  • path (str) – Path to save the plot to relative to the output directory.

  • close (bool) – Whether to close the figure after saving.

  • tight_layout (bool) – Whether to apply tight layout to the plot.

class autrainer.core.plotting.PlotMetrics(output_directory, training_type, figsize, latex, filetypes, pickle, context, palette, replace_none, add_titles, add_xlabels, add_ylabels, rcParams, metric_fns)

Plot the metrics of one or multiple runs.

  • output_directory (str) – Output directory to save plots to.

  • training_type (str) – Type of training in [“Epoch”, “Step”].

  • figsize (tuple) – Figure size in inches.

  • latex (bool) – Whether to use LaTeX in plots. Requires the latex package. To install all necessary dependencies, run: pip install autrainer[latex].

  • filetypes (list) – Filetypes to save plots as.

  • pickle (bool) – Whether to save additional pickle files of the plots.

  • context (str) – Context for seaborn plots.

  • palette (str) – Color palette for seaborn plots.

  • replace_none (bool) – Whether to replace “None” in labels with “~”.

  • add_titles (bool) – Whether to add titles to plots.

  • add_xlabels (bool) – Whether to add x-labels to plots.

  • add_ylabels (bool) – Whether to add y-labels to plots.

  • rcParams (dict) – Additional Matplotlib rcParams to set.

  • metric_fns (List[AbstractMetric]) – List of metrics to use for plotting.

Default Configurations


figsize: [10, 5]
latex: false
filetypes: [png]
pickle: false
context: notebook
palette: colorblind
replace_none: false
add_titles: true
add_xlabels: true
add_ylabels: true
  legend.fontsize: 9

plot_run(metrics, std_scale=0.1)

Plot the metrics of a single run.

  • metrics (DataFrame) – DataFrame containing the metrics.

  • std_scale (float) – Scale factor for the standard deviation. Defaults to 0.1.

plot_metric(metrics, metric, metrics_std=None, std_scale=0.1, max_runs=None)

Plot a single metric of multiple runs.

  • metrics (DataFrame) – DataFrame containing the metrics.

  • metric (str) – Metric to plot.

  • metrics_std (Optional[DataFrame]) – DataFrame containing the standard deviations. Defaults to None.

  • std_scale (float) – Scale factor for the standard deviation. Defaults to 0.1.

  • max_runs (Optional[int]) – Maximum number of best runs to plot. If None, all runs are plotted. Defaults to None.

plot_aggregated_bars(metrics_df, metric, subplots_by=0, group_by=1, split_subgroups=True)

Plot aggregated bar plots for a metric.

Generate bar plots from metrics_df. The data is divided into subplots by the “subplots_by” column and, within each subplot, grouped by the “group_by” column. If “split_subgroups” is True, each group is further split into subgroups based on what follows a “-” in the “group_by” entry, if one is present. Finally, the “metric” entries are averaged to create the bars, and the standard deviation is shown as error bars.

  • metrics_df (DataFrame) – DataFrame containing the metrics.

  • metric (str) – Metric to plot.

  • subplots_by (int) – Column to group the subplots by.

  • group_by (int) – Column to group the data by.

  • split_subgroups (bool) – Whether to split subgroups.
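The aggregation described above can be sketched without any plotting. The code below is a hypothetical illustration on made-up toy rows, not the PlotMetrics implementation: aggregate_bars is an invented name, splitting at the first “-” is an assumption, and using the sample standard deviation (0.0 for a single value) is also an assumption.

```python
from collections import defaultdict
from statistics import mean, stdev

# Made-up rows for illustration: (subplot key, group entry, metric value).
rows = [
    ("DatasetA", "Adam-0.001", 0.80),
    ("DatasetA", "Adam-0.001", 0.84),
    ("DatasetA", "Adam-0.01", 0.70),
    ("DatasetA", "SGD-0.01", 0.60),
    ("DatasetA", "SGD-0.01", 0.66),
]


def aggregate_bars(rows, split_subgroups=True):
    """Return {(subplot, group, subgroup): (bar height, error bar)}."""
    buckets = defaultdict(list)
    for subplot, entry, value in rows:
        if split_subgroups:
            # Assumption: everything after the first "-" names the subgroup.
            group, _, subgroup = entry.partition("-")
        else:
            group, subgroup = entry, ""
        buckets[(subplot, group, subgroup)].append(value)
    return {
        key: (mean(vals), stdev(vals) if len(vals) > 1 else 0.0)
        for key, vals in buckets.items()
    }


bars = aggregate_bars(rows)
unsplit = aggregate_bars(rows, split_subgroups=False)
```

Each key of bars corresponds to one bar: the first element selects the subplot, the second the group, and the third the subgroup within that group.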


Constants

Constants provide various predefined values.

autrainer.core.constants.NAMING_CONVENTION = ['dataset', 'model', 'optimizer', 'learning_rate', 'batch_size', 'training_type', 'iterations', 'scheduler', 'augmentation', 'seed']

autrainer.core.constants.INVALID_AGGREGATIONS = ['training_type']

autrainer.core.constants.VALID_AGGREGATIONS = ['scheduler', 'dataset', 'augmentation', 'seed', 'model', 'optimizer', 'learning_rate', 'batch_size', 'iterations']

autrainer.core.constants.CONFIG_FOLDERS = ['augmentation', 'dataset', 'model', 'optimizer', 'plotting', 'preprocessing', 'scheduler', 'sharpness']
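Judging from the values listed above, VALID_AGGREGATIONS contains exactly the NAMING_CONVENTION entries that are not in INVALID_AGGREGATIONS. The following check copies the values so that it runs without autrainer installed; whether the library derives the list this way is an assumption.

```python
# Values copied from the constants listed above.
NAMING_CONVENTION = [
    "dataset", "model", "optimizer", "learning_rate", "batch_size",
    "training_type", "iterations", "scheduler", "augmentation", "seed",
]
INVALID_AGGREGATIONS = ["training_type"]
VALID_AGGREGATIONS = [
    "scheduler", "dataset", "augmentation", "seed", "model",
    "optimizer", "learning_rate", "batch_size", "iterations",
]

# Every naming component except the invalid ones can be aggregated over.
derived = [key for key in NAMING_CONVENTION if key not in INVALID_AGGREGATIONS]
same_members = set(derived) == set(VALID_AGGREGATIONS)
```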

