Inference

Inference offers an interface to obtain predictions or embeddings from a trained model. In addition, a sliding window can be used to obtain predictions or embeddings from parts of the input data.

Both the autrainer inference CLI command and the inference() CLI wrapper function run inference on audio data with a trained model, optionally using a sliding window.

Note

Currently, inference is only supported for audio data.

Audio Inference

class autrainer.serving.Inference(model_path, checkpoint='_best', device='cpu', preprocess_cfg='default', window_length=None, stride_length=None, min_length=None, sample_rate=None)

Inference class for audio models.

Parameters:
  • model_path (str) – Local model directory containing the model.yaml, file_handler.yaml, target_transform.yaml, and inference_transform.yaml files.

  • checkpoint (str) – Checkpoint directory containing a model.pt file. Defaults to “_best”.

  • device (str) – Device to run inference on. Defaults to “cpu”.

  • preprocess_cfg (Optional[str]) – Preprocessing configuration file. If “default”, the default preprocessing pipeline used during training is applied. If None, no preprocessing is applied. Defaults to “default”.

  • window_length (Optional[float]) – Window length in seconds for sliding window inference. Defaults to None.

  • stride_length (Optional[float]) – Stride length in seconds for sliding window inference. Defaults to None.

  • min_length (Optional[float]) – Minimum length of an audio file in seconds. Audio files shorter than this will be padded with zeros. Defaults to None.

  • sample_rate (Optional[int]) – Sample rate of the audio files in Hz. Defaults to None.
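
The window_length, stride_length, min_length, and sample_rate parameters above are all given in seconds or Hz, so it can help to see how they translate into sample counts. The following is a minimal sketch in plain Python, not autrainer's implementation: the helper names window_offsets and pad_to_min_length are hypothetical, and both the handling of a trailing partial window and padding at the end of the signal are assumptions.

```python
from typing import List, Optional


def window_offsets(
    num_samples: int,
    sample_rate: int,
    window_length: float,
    stride_length: float,
) -> List[float]:
    """Return the start offset (in seconds) of every full window."""
    window = int(window_length * sample_rate)  # window size in samples
    stride = int(stride_length * sample_rate)  # hop size in samples
    if window <= 0 or stride <= 0:
        raise ValueError("window_length and stride_length must be positive")
    if num_samples < window:
        # Assumption: a signal shorter than one window yields a single window.
        return [0.0]
    # Assumption: a trailing partial window is dropped.
    starts = range(0, num_samples - window + 1, stride)
    return [start / sample_rate for start in starts]


def pad_to_min_length(
    samples: List[float],
    sample_rate: int,
    min_length: Optional[float],
) -> List[float]:
    """Zero-pad a signal so that it lasts at least min_length seconds."""
    if min_length is None:
        return list(samples)
    target = int(min_length * sample_rate)
    missing = max(0, target - len(samples))
    # Assumption: zeros are appended at the end of the signal.
    return list(samples) + [0.0] * missing
```

For a 3 second file at 16 kHz with a 1 second window and a 0.5 second stride, this sketch produces windows starting every half second from 0.0 to 2.0 seconds.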

predict_directory(directory, extension, recursive=False)

Obtain the model predictions for all files in a directory.

Parameters:
  • directory (str) – Path to the directory containing audio files.

  • extension (str) – File extension of the audio files.

  • recursive (bool) – Whether to search recursively for audio files in subdirectories. Defaults to False.

Return type:

DataFrame

Returns:

DataFrame containing the filename, prediction, and output for each file. If sliding window inference is used, the offset is additionally included.
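
A typical workflow combines the constructor, predict_directory(), and the two saving helpers documented on this page. The sketch below uses only the signatures listed here; the wrapper function predict_and_save and its argument names are hypothetical, and the window and stride values are arbitrary examples. The expected format of extension (with or without a leading dot) is not stated on this page, so it is passed through unchanged.

```python
import os


def predict_and_save(
    model_path: str,
    audio_dir: str,
    extension: str,
    output_dir: str,
) -> None:
    """Run sliding window inference on a directory and save the results."""
    # Imported lazily so this sketch can be loaded without autrainer installed.
    from autrainer.serving import Inference

    inference = Inference(
        model_path=model_path,
        checkpoint="_best",
        device="cpu",
        window_length=1.0,  # example value, in seconds
        stride_length=0.5,  # example value, in seconds
    )
    results = inference.predict_directory(audio_dir, extension, recursive=True)

    # Assumption: creating the output directory up front is harmless.
    os.makedirs(output_dir, exist_ok=True)
    Inference.save_prediction_results(results, output_dir)  # CSV file
    Inference.save_prediction_yaml(results, output_dir)  # human-readable YAML
```

Because window_length and stride_length are set, the returned DataFrame should additionally include the offset of each window, as described above.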

embed_directory(directory, extension, recursive=False)

Obtain the model embeddings for all files in a directory.

Parameters:
  • directory (str) – Path to the directory containing audio files.

  • extension (str) – File extension of the audio files.

  • recursive (bool) – Whether to search recursively for audio files in subdirectories. Defaults to False.

Return type:

DataFrame

Returns:

DataFrame containing the filename and embedding for each file. If sliding window inference is used, the offset is additionally included.
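
Both directory methods select files by extension and, optionally, descend into subdirectories. To preview which files such a selection would cover before running a model, a small standalone helper can be used. This is a sketch under assumptions, not the library's own discovery logic: find_audio_files is a hypothetical name, and matching on the exact file suffix is an assumption.

```python
from pathlib import Path
from typing import List


def find_audio_files(
    directory: str,
    extension: str,
    recursive: bool = False,
) -> List[Path]:
    """List files in a directory whose suffix matches the given extension."""
    # Accept the extension with or without a leading dot.
    suffix = "." + extension.lstrip(".")
    root = Path(directory)
    # rglob descends into subdirectories, glob only inspects the top level.
    candidates = root.rglob("*") if recursive else root.glob("*")
    return sorted(p for p in candidates if p.is_file() and p.suffix == suffix)
```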

predict_file(file)

Obtain the model prediction for a single file.

Parameters:
  • file (str) – Path to the audio file.

Return type:

Union[Tuple[Any, Tensor], Tuple[Dict[str, Any], Dict[str, Tensor]]]

Returns:

Model prediction and output for the file. If sliding window inference is used, the prediction is a dictionary with the offset as the key.
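
Since the return type differs between whole-file and sliding window inference, calling code usually needs to handle both shapes. The helper below is a minimal sketch that flattens either shape into a list of rows; prediction_rows is a hypothetical name, and it assumes that the prediction and output dictionaries share the same offset keys.

```python
from typing import Any, List, Optional, Tuple


def prediction_rows(
    prediction: Any,
    output: Any,
) -> List[Tuple[Optional[str], Any, Any]]:
    """Flatten a predict_file() result into (offset, prediction, output) rows."""
    if isinstance(prediction, dict):
        # Sliding window case: one row per offset.
        # Assumption: output is a dictionary with the same keys as prediction.
        return [
            (offset, pred, output[offset])
            for offset, pred in prediction.items()
        ]
    # Whole-file case: a single row without an offset.
    return [(None, prediction, output)]
```

It would be called as prediction_rows(*inference.predict_file(path)), which lets downstream code iterate over rows without checking whether a sliding window was configured.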

embed_file(file)

Obtain the model embedding for a single file.

Parameters:
  • file (str) – Path to the audio file.

Return type:

Union[Tensor, Dict[str, Tensor]]

Returns:

Model embedding for the file. If sliding window inference is used, the embedding is a dictionary with the offset as the key.

static save_prediction_yaml(results, output_dir)

Save the prediction results to a YAML file.

Creates a human-readable YAML file with the prediction results.

Parameters:
  • results (DataFrame) – DataFrame containing the results.

  • output_dir (str) – Output directory to save the results to.

Return type:

None

static save_prediction_results(results, output_dir)

Save the prediction results to a CSV file.

Creates a CSV file with the model predictions and outputs.

Parameters:
  • results (DataFrame) – DataFrame containing the results.

  • output_dir (str) – Output directory to save the results to.

Return type:

None

static save_embeddings(results, output_dir, input_extension)

Save the embeddings as torch tensors.

Saves the embeddings to the output directory with the same filename as each audio file.

Parameters:
  • results (DataFrame) – DataFrame containing the embeddings.

  • output_dir (str) – Output directory to save the embeddings to.

  • input_extension (str) – File extension of the input audio files to replace with “.pt”.

Return type:

None
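
The description of save_embeddings() implies a simple mapping from each input filename to a ".pt" file in the output directory. The sketch below illustrates that mapping only; embedding_path is a hypothetical helper, and keeping the relative subdirectory of the input file is an assumption, as is the handling of the leading dot in input_extension. How offsets are encoded for sliding window embeddings is not covered here.

```python
from pathlib import Path


def embedding_path(filename: str, output_dir: str, input_extension: str) -> Path:
    """Map an input audio filename to the path of its saved embedding."""
    # Accept the extension with or without a leading dot.
    suffix = "." + input_extension.lstrip(".")
    name = Path(filename)
    if name.suffix != suffix:
        raise ValueError(f"{filename!r} does not end with {suffix!r}")
    # Assumption: the relative subdirectory of the input file is preserved.
    return Path(output_dir) / name.with_suffix(".pt")
```

Under these assumptions, the saved files could later be read back with torch.load().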