Source code for autrainer.core.utils.bookkeeping

import logging
import os
import sys
from typing import TYPE_CHECKING, Dict, List, Optional, Union
import warnings

import audobject
import numpy as np
import pandas as pd
import torch
from torchinfo import summary
import yaml

from autrainer.metrics import AbstractMetric


if TYPE_CHECKING:
    from autrainer.datasets import AbstractDataset  # pragma: no cover


class Bookkeeping:
    def __init__(
        self,
        output_directory: str,
        file_handler_path: Optional[str] = None,
    ) -> None:
        """Bookkeeping to handle general disk operations and interactions.

        Args:
            output_directory: Output directory to save files to.
            file_handler_path: Path to save the log file to.
                Defaults to None.
        """
        self.output_directory = output_directory
        os.makedirs(self.output_directory, exist_ok=True)
        self.original_stdout = sys.stdout

        # Set up custom logging on the root logger.
        self.logger = logging.getLogger()
        if not self.logger.hasHandlers() or file_handler_path is not None:
            self._setup_logger(file_handler_path)
        for handler in self.logger.handlers:
            if isinstance(handler, logging.StreamHandler):
                handler.setFormatter(
                    logging.Formatter(
                        "[%(asctime)s][%(levelname)s]\n%(message)s\n"
                    )
                )

    def _setup_logger(self, fp: Optional[str] = None) -> None:
        self.logger.setLevel(logging.INFO)
        if not self.logger.hasHandlers():
            self.logger.addHandler(logging.StreamHandler())
        if fp is not None:
            self.logger.addHandler(logging.FileHandler(fp))
        else:
            self.logger.addHandler(
                logging.FileHandler(
                    os.path.join(self.output_directory, "bookkeeping.log")
                )
            )
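    # Construction sketch (illustrative paths, not part of this module):
    #
    #     bk = Bookkeeping("results/run_1")
    #     # logs to results/run_1/bookkeeping.log
    #     bk = Bookkeeping("results/run_1", "results/run_1/train.log")
    #
    # Note that the handlers are attached to the *root* logger, so they
    # are shared with any other logging in the process.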
    def log(self, message: str, level: int = logging.INFO) -> None:
        """Log a message.

        Args:
            message: Message to log.
            level: Logging level. Defaults to logging.INFO.
        """
        self.logger.log(level, message)
    def log_to_file(self, message: str, level: int = logging.INFO) -> None:
        """Log a message to the file handler.

        Args:
            message: Message to log.
            level: Logging level. Defaults to logging.INFO.
        """
        for handler in self.logger.handlers:
            if isinstance(handler, logging.FileHandler):
                handler.emit(
                    logging.LogRecord(
                        self.logger.name,
                        level,
                        None,
                        None,
                        message,
                        None,
                        None,
                    )
                )
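    # Usage sketch: ``log`` emits to all handlers, while ``log_to_file``
    # bypasses stream handlers and writes only to file handlers
    # (illustrative calls):
    #
    #     bk.log("starting training")
    #     bk.log_to_file("per-batch details omitted from the console")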
    def create_folder(self, folder_name: str, path: str = "") -> None:
        """Create a new folder in the output directory.

        Args:
            folder_name: Name of the folder to create.
            path: Subdirectory to create the folder in. Defaults to "".
        """
        os.makedirs(
            os.path.join(self.output_directory, path, folder_name),
            exist_ok=True,
        )
    def save_model_summary(
        self,
        model: torch.nn.Module,
        dataset: "AbstractDataset",
        filename: str,
    ) -> None:
        """Save a model summary to a file.

        Args:
            model: Model to summarize.
            dataset: Dataset to get the input size from.
            filename: Name of the file to save the summary to.
        """
        # Derive the input size from the first dataset item, with a
        # leading batch dimension of 1.
        x = np.expand_dims(dataset[0][0], axis=0).shape
        with open(
            os.path.join(self.output_directory, filename),
            "w",
            encoding="utf-8",
        ) as f:
            # Redirect stdout so torchinfo's summary table is written to
            # the file instead of the console.
            sys.stdout = f
            s = summary(
                model=model,
                input_size=x,
                col_names=[
                    "input_size",
                    "output_size",
                    "num_params",
                    "trainable",
                ],
                col_width=20,
                row_settings=["var_names"],
            )
        sys.stdout = self.original_stdout
        model_summary = {
            "total_mult_adds": s.total_mult_adds,
            "total_output_bytes": s.total_output_bytes,
            "total_params": s.total_params,
            "trainable_params": s.trainable_params,
            "total_param_bytes": s.total_param_bytes,
        }
        with open(
            os.path.join(
                self.output_directory, filename.replace(".txt", ".yaml")
            ),
            "w",
        ) as f:
            yaml.dump(model_summary, f)
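    # Usage sketch (hypothetical ``model`` and ``dataset``; the first
    # element of the first dataset item is assumed to be the input):
    #
    #     bk.save_model_summary(model, dataset, "model_summary.txt")
    #     # writes model_summary.txt (torchinfo table) and
    #     # model_summary.yaml (parameter and mult-add statistics)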
    def save_state(
        self,
        obj: Union[
            torch.nn.Module,
            torch.optim.Optimizer,
            torch.optim.lr_scheduler.LRScheduler,
        ],
        filename: str,
        path: str = "",
    ) -> None:
        """Save the state of an object.

        Args:
            obj: Object to save the state of.
            filename: Name of the file to save the state to.
            path: Subdirectory to save the state to. Defaults to "".

        Raises:
            TypeError: If the object type is not supported.
        """
        p = os.path.join(self.output_directory, path, filename)
        _i = (
            torch.nn.Module,
            torch.optim.Optimizer,
            torch.optim.lr_scheduler.LRScheduler,
        )
        if not isinstance(obj, _i):
            raise TypeError(
                f"save_state of type {type(obj)} is not supported."
            )
        os.makedirs(os.path.join(self.output_directory, path), exist_ok=True)
        torch.save(obj.state_dict(), p)
    def load_state(
        self,
        obj: Union[
            torch.nn.Module,
            torch.optim.Optimizer,
            torch.optim.lr_scheduler.LRScheduler,
        ],
        filename: str,
        path: str = "",
    ) -> None:
        """Load the state of an object.

        Args:
            obj: Object to load the state into.
            filename: Name of the file to load the state from.
            path: Subdirectory to load the state from. Defaults to "".

        Raises:
            TypeError: If the object type is not supported.
            FileNotFoundError: If the file is not found.
        """
        p = os.path.join(self.output_directory, path, filename)
        if not os.path.isfile(p):
            raise FileNotFoundError(f"File {p} not found.")
        _i = (
            torch.nn.Module,
            torch.optim.Optimizer,
            torch.optim.lr_scheduler.LRScheduler,
        )
        if not isinstance(obj, _i):
            raise TypeError(
                f"load_state of type {type(obj)} is not supported."
            )
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning)
            state_dict = torch.load(p, map_location="cpu", weights_only=True)
        obj.load_state_dict(state_dict)
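    # Round-trip sketch for ``save_state``/``load_state`` (hypothetical
    # ``model`` and ``optimizer``; "_best" is an illustrative subfolder):
    #
    #     bk.save_state(model, "model.pt", path="_best")
    #     bk.save_state(optimizer, "optimizer.pt", path="_best")
    #     bk.load_state(model, "model.pt", path="_best")
    #
    # States are loaded onto the CPU with ``weights_only=True``, so move
    # the model to its target device after loading.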
    def save_audobject(
        self,
        obj: audobject.Object,
        filename: str,
        path: str = "",
    ) -> None:
        """Save an audobject.Object to disk.

        Args:
            obj: Object to save.
            filename: Name of the file to save the object to.
            path: Subdirectory to save the object to. Defaults to "".

        Raises:
            TypeError: If the object type is not supported.
        """
        if not isinstance(obj, audobject.Object):
            raise TypeError(
                f"save_audobject of type {type(obj)} is not supported."
            )
        os.makedirs(os.path.join(self.output_directory, path), exist_ok=True)
        obj.to_yaml(os.path.join(self.output_directory, path, filename))
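    # Usage sketch (any ``audobject.Object``, e.g. a serializable
    # preprocessing transform; the file name is illustrative):
    #
    #     bk.save_audobject(transform, "preprocess.yaml")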
    def save_results_dict(
        self,
        results_dict: Dict[str, float],
        filename: str,
        path: str = "",
    ) -> None:
        """Save a results dictionary to disk.

        Args:
            results_dict: Dictionary of metric names and values to save.
            filename: Name of the file to save the results to.
            path: Subdirectory to save the results to. Defaults to "".
        """
        os.makedirs(os.path.join(self.output_directory, path), exist_ok=True)
        with open(
            os.path.join(self.output_directory, path, filename),
            "w",
            encoding="utf-8",
        ) as f:
            yaml.dump(results_dict, f)
    def save_results_df(
        self, results_df: pd.DataFrame, filename: str, path: str = ""
    ) -> None:
        """Save a results DataFrame to disk.

        Args:
            results_df: DataFrame to save.
            filename: Name of the file to save the results to.
            path: Subdirectory to save the results to. Defaults to "".
        """
        os.makedirs(os.path.join(self.output_directory, path), exist_ok=True)
        results_df.to_csv(
            os.path.join(self.output_directory, path, filename), index=False
        )
    def save_results_np(
        self, results_np: np.ndarray, filename: str, path: str = ""
    ) -> None:
        """Save a results numpy array to disk.

        Args:
            results_np: Numpy array to save.
            filename: Name of the file to save the results to.
            path: Subdirectory to save the results to. Defaults to "".
        """
        os.makedirs(os.path.join(self.output_directory, path), exist_ok=True)
        np.save(
            os.path.join(self.output_directory, path, filename), results_np
        )
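    # Usage sketch for the three result writers (illustrative values;
    # ``metrics_df`` and ``predictions`` are hypothetical objects):
    #
    #     bk.save_results_dict({"accuracy": 0.91}, "test_results.yaml")
    #     bk.save_results_df(metrics_df, "metrics.csv")
    #     bk.save_results_np(predictions, "predictions.npy", path="_test")
    #
    # ``np.save`` appends ".npy" to the file name if it lacks an
    # extension, so pass the extension explicitly for predictable paths.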
    def save_best_results(
        self,
        metrics: pd.DataFrame,
        filename: str,
        metric_fns: List[AbstractMetric],
        tracking_metric_fn: AbstractMetric,
        path: str = "",
    ) -> None:
        """Save the best results to disk.

        Args:
            metrics: DataFrame of metrics to save.
            filename: Name of the file to save the best results to.
            metric_fns: List of metric functions to get the best results
                from.
            tracking_metric_fn: Tracking metric function to get the best
                iteration from.
            path: Subdirectory to save the best results to. Defaults to "".
        """
        best_metrics = {}
        # Record the minimum of every loss column.
        for m in metrics:
            if "loss" in m:
                best_metrics[f"{m}_min"] = float(metrics[m].min())
        # Record the best value of each metric and the iteration at which
        # the tracking metric was best.
        for m in metric_fns:
            best_metrics[m.name] = m.get_best(metrics[m.name])
        best_metrics["best_iteration"] = tracking_metric_fn.get_best_pos(
            metrics[tracking_metric_fn.name]
        )
        os.makedirs(os.path.join(self.output_directory, path), exist_ok=True)
        with open(
            os.path.join(self.output_directory, path, filename),
            "w",
            encoding="utf-8",
        ) as f:
            yaml.dump(best_metrics, f)
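A minimal end-to-end sketch of how these pieces fit together (hypothetical ``model`` and metric values; autrainer's training loop normally drives these calls):

    import logging

    from autrainer.core.utils.bookkeeping import Bookkeeping

    bk = Bookkeeping("results/run_1")
    bk.log("training finished", level=logging.INFO)
    bk.create_folder("_best")
    bk.save_state(model, "model.pt", path="_best")  # hypothetical model
    bk.save_results_dict({"train_loss_min": 0.42}, "best_results.yaml")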