# Source code for aucurriculum.core.scripts.train_script

from typing import Optional

import autrainer
from autrainer.core.scripts import TrainScript
from autrainer.core.scripts.abstract_script import MockParser
from autrainer.core.scripts.utils import (
    add_hydra_args_to_sys,
    catch_cli_errors,
    run_hydra_cmd,
    running_in_notebook,
)
from omegaconf import DictConfig, OmegaConf


class TrainCurriculumScript(TrainScript):
    """Training script that optionally adds the curriculum pace manager.

    Extends ``TrainScript`` by overriding ``main``: unless the configured
    curriculum id is the string ``"None"``, the
    ``aucurriculum.curricula.CurriculumPaceManager`` callback is appended to
    the configuration's callbacks before the trainer is created.
    """

    def main(self, args: dict) -> None:
        """Run a single training configuration.

        Args:
            args: Script arguments. Not read by this method; the inner task
                function is invoked without arguments and receives its
                configuration from the ``autrainer.main`` decorator.
        """

        @autrainer.main("config")
        def main(cfg: DictConfig) -> float:
            """Train one configuration, or reuse the result of a finished run.

            Args:
                cfg: Configuration of the run.

            Returns:
                The stored best dev value of the tracking metric if the run
                already exists, otherwise the value returned by
                ``ModularTaskTrainer.train``.
            """
            # ? Heavy imports are deferred until the task actually runs
            import os

            from hydra.core.hydra_config import HydraConfig

            # ? Allow adding keys (callbacks) and freeze all interpolations
            OmegaConf.set_struct(cfg, False)
            OmegaConf.resolve(cfg)
            output_dir = HydraConfig.get().runtime.output_dir

            # ? Skip if run exists and return best tracking metric
            if os.path.exists(os.path.join(output_dir, "metrics.csv")):
                import autrainer
                from autrainer.core.utils import Bookkeeping
                from autrainer.metrics import AbstractMetric

                dev_metrics = OmegaConf.load(
                    os.path.join(output_dir, "_best", "dev.yaml")
                )
                tracking_metric = autrainer.instantiate_shorthand(
                    config=cfg.dataset.tracking_metric,
                    instance_of=AbstractMetric,
                )
                best_metric = dev_metrics[tracking_metric.name]["all"]

                bookkeeping = Bookkeeping(output_dir)
                bookkeeping.log(f"Skipping: {os.path.basename(output_dir)}")

                return best_metric

            # ? Save cfg to output directory
            cfg_path = os.path.join(output_dir, ".hydra", "config.yaml")
            os.makedirs(os.path.dirname(cfg_path), exist_ok=True)
            OmegaConf.save(cfg, cfg_path)

            from autrainer.training import ModularTaskTrainer

            if cfg.curriculum.id != "None":
                # ? `callbacks` may be missing or explicitly null; a default
                # ? passed to `get` only covers the missing case, so fall
                # ? back to an empty list for both
                callbacks = cfg.get("callbacks") or []
                callbacks.append(
                    "aucurriculum.curricula.CurriculumPaceManager"
                )
                cfg.callbacks = callbacks

            trainer = ModularTaskTrainer(
                cfg=cfg,
                output_directory=output_dir,
            )
            return trainer.train()

        main()


@catch_cli_errors
def train(
    override_kwargs: Optional[dict] = None,
    config_name: str = "config",
    config_path: Optional[str] = None,
) -> None:
    """Launch a training configuration.

    Args:
        override_kwargs: Additional Hydra override arguments to pass to the
            train script.
        config_name: The name of the config (usually the file name without
            the .yaml extension). Defaults to "config".
        config_path: The config path, a directory where Hydra will search
            for config files. If config_path is None no directory is added
            to the search path. Defaults to None.
    """
    if running_in_notebook():
        # In a notebook the training is launched through the
        # ``aucurriculum train`` command rather than in-process.
        run_hydra_cmd(
            "train",
            override_kwargs,
            config_name,
            config_path,
            cmd_prefix="aucurriculum",
        )
    else:
        # Expose the overrides as command-line arguments, then run the
        # script in-process with a mock parser and empty script arguments.
        add_hydra_args_to_sys(override_kwargs, config_name, config_path)
        script = TrainCurriculumScript()
        script.parser = MockParser()
        script.main({})