Source code for autrainer.loggers.mlflow_logger

import os
from pathlib import Path
from typing import Dict, List, Union
import warnings

from omegaconf import DictConfig

from autrainer.core.constants import ExportConstants
from autrainer.metrics import AbstractMetric

from .abstract_logger import (
    AbstractLogger,
    get_params_to_export,
)
from .fallback_logger import FallbackLogger


# mlflow is an optional dependency: record whether it could be imported so the
# logger class can be swapped for a fallback further down in this module.
try:
    with warnings.catch_warnings():
        # Silence DeprecationWarnings emitted while mlflow is being imported.
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        import mlflow
except ImportError:  # pragma: no cover
    MLFLOW_AVAILABLE = False  # pragma: no cover
else:
    MLFLOW_AVAILABLE = True


class MLFlowLogger(AbstractLogger):
    """Logger that records parameters, metrics and artifacts with MLflow."""

    def __init__(
        self,
        exp_name: str,
        run_name: str,
        metrics: List[AbstractMetric],
        tracking_metric: AbstractMetric,
        artifacts: List[
            Union[str, Dict[str, str]]
        ] = ExportConstants().ARTIFACTS,
        output_dir: str = "mlruns",
    ) -> None:
        """Create the logger and resolve the MLflow tracking URI.

        Args:
            exp_name: Name of the MLflow experiment to log to.
            run_name: Name of the MLflow run to create.
            metrics: Metrics forwarded to the parent logger.
            tracking_metric: Tracking metric forwarded to the parent logger.
            artifacts: Artifacts forwarded to the parent logger. Defaults to
                ``ExportConstants().ARTIFACTS``.
            output_dir: Either a filesystem path or a tracking URI starting
                with ``file://``, ``http://`` or ``https://``. Paths are made
                absolute and converted to a ``file://`` URI; URIs are used
                unchanged. Defaults to ``"mlruns"``.
        """
        super().__init__(
            exp_name, run_name, metrics, tracking_metric, artifacts
        )
        # The scheme check has to happen on the raw string: wrapping the value
        # in a Path first collapses "//" into "/" (e.g. "http://host" becomes
        # "http:/host"), so a URI would never be recognised and would be
        # rewritten into a bogus local file:// location.
        raw_dir = str(output_dir)
        if raw_dir.startswith(("file://", "http://", "https://")):
            self.output_dir = raw_dir
        else:
            # Path.absolute() leaves already-absolute paths untouched.
            self.output_dir = Path(raw_dir).absolute().as_uri()

    def setup(self) -> None:
        """Set the tracking URI, select the experiment and start the run."""
        mlflow.set_tracking_uri(self.output_dir)
        self.exp_id = self._get_or_create_experiment()
        mlflow.set_experiment(experiment_id=self.exp_id)
        self.run = self._get_or_create_run()

    def _get_or_create_experiment(self) -> str:
        """Return the ID of the experiment named ``exp_name``.

        The experiment is created if no experiment with that name exists.
        """
        experiment = mlflow.get_experiment_by_name(self.exp_name)
        if experiment:
            return experiment.experiment_id
        return mlflow.create_experiment(name=self.exp_name)

    def _get_or_create_run(self) -> "mlflow.ActiveRun":
        """Start a new run named ``run_name``.

        A previously logged run with the same name is deleted first, so the
        returned run is always newly started.
        """
        self._delete_run_if_exists(self.run_name)
        return mlflow.start_run(run_name=self.run_name)

    def _delete_run_if_exists(self, run_name: str) -> None:
        """Delete the first run of the current experiment named ``run_name``.

        Args:
            run_name: Run name to search for via the ``mlflow.runName`` tag.
        """
        # NOTE(review): run_name is interpolated into the filter string
        # without escaping; a name containing a single quote would produce an
        # invalid filter.
        runs = mlflow.search_runs(
            experiment_ids=[self.exp_id],
            filter_string=f"tags.mlflow.runName='{run_name}'",
        )
        if runs.shape[0] > 0:
            # Only the first matching row is removed.
            run_id = runs.iloc[0]["run_id"]
            mlflow.MlflowClient().delete_run(run_id)

    def log_params(self, params: Union[dict, DictConfig]) -> None:
        """Log the exportable subset of ``params`` as MLflow parameters.

        Args:
            params: Parameters, filtered through ``get_params_to_export``
                before being logged.
        """
        mlflow.log_params(get_params_to_export(params))

    def log_metrics(
        self,
        metrics: Dict[str, Union[int, float]],
        iteration: Union[int, None] = None,
    ) -> None:
        """Log metric values to the active MLflow run.

        Args:
            metrics: Mapping of metric names to values.
            iteration: Value passed to MLflow as the ``step``. Defaults to
                None.
        """
        mlflow.log_metrics(metrics, step=iteration)

    def log_timers(self, timers: Dict[str, float]) -> None:
        """Log timer values as MLflow parameters.

        Args:
            timers: Mapping of timer names to values.
        """
        mlflow.log_params(timers)

    def log_artifact(self, filename: str, path: str = "") -> None:
        """Log the file at ``path/filename`` as an MLflow artifact.

        Args:
            filename: Name of the file to log.
            path: Directory joined in front of ``filename``. Defaults to "".
        """
        mlflow.log_artifact(os.path.join(path, filename))

    def end_run(self) -> None:
        """End the active MLflow run."""
        mlflow.end_run()
# Without mlflow installed, rebind the public name to a factory that ignores
# every argument and hands back a FallbackLogger instead; otherwise the class
# defined above is kept as is.
if not MLFLOW_AVAILABLE:  # pragma: no cover
    MLFlowLogger = lambda *args, **kwargs: FallbackLogger(  # noqa: E731
        "MLFlowLogger", "mlflow"
    )