Source code for mml.core.scripts.pipeline_configuration

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

import logging
import warnings
from contextlib import contextmanager
from copy import deepcopy
from pathlib import Path
from typing import List, Optional

from omegaconf import DictConfig, OmegaConf

from mml.core.data_loading.file_manager import MMLFileManager
from mml.core.data_loading.task_struct import TaskStruct

logger = logging.getLogger(__name__)


[docs] class PipelineCfg:
[docs] def __init__(self, pipeline_cfg: DictConfig, restrict_keys: Optional[List[str]] = None) -> None: """ PipelineCfg holds relevant configuration elements of a training pipeline to store, reproduce and leverage knowledge at a later point. The intended usage is by invoking :meth:`from_cfg` on the full mml config, which will produce a masked copy only focussing on a subset of config keys. :param DictConfig pipeline_cfg: a config :param Optional[List[str]] restrict_keys: which config keys to focus upon """ self.pipeline_cfg = pipeline_cfg self.pipeline_keys = [] restrict_keys = PIPELINE_CONFIG_PARTS if restrict_keys is None else restrict_keys for key in restrict_keys: if not isinstance(key, str): raise ValueError("provide pipeline keys as strings") if not hasattr(self.pipeline_cfg, key): warnings.warn(f"requested key {key} not found in pipeline_cfg, will be ignored") continue self.pipeline_keys.append(key) if len(self.pipeline_keys) == 0: raise ValueError("(valid) pipeline keys are empty") # reduce pipeline_config, this is a fallback if called directly upon a full config, but also in case at some # point only a subset of an existing pipeline configuration is intended to be reused self.pipeline_cfg = OmegaConf.masked_copy(self.pipeline_cfg, keys=self.pipeline_keys)
[docs] @classmethod def from_cfg(cls, current_cfg: DictConfig, restrict_keys: Optional[List[str]] = None) -> "PipelineCfg": """ Extracts relevant pipeline keys from current config determined by restrict_keys. :param DictConfig current_cfg: the FULL config to derive the pipeline configuration from :param Optional[List[str]] restrict_keys: which config keys to focus upon :return: """ pipeline_keys = PIPELINE_CONFIG_PARTS if restrict_keys is None else restrict_keys if not all([isinstance(key, str) and hasattr(current_cfg, key) for key in pipeline_keys]): raise ValueError(f"keys {pipeline_keys} contains a value, which might be not present in the current config") return cls(pipeline_cfg=OmegaConf.masked_copy(current_cfg, keys=pipeline_keys), restrict_keys=pipeline_keys)
[docs] @contextmanager def activate(self, current_cfg: DictConfig) -> None: """ To be used as a config manager, activates this pipeline upon the currently active mml config. When the context exits, the original configuration is restored. :param DictConfig current_cfg: the currently active mml config :return: no return value, the mml config is modified in place """ # create backup for later restoration old = deepcopy(current_cfg) # set config elements based on keys for key in self.pipeline_keys: if key in self.pipeline_cfg: OmegaConf.update(current_cfg, key=key, value=self.pipeline_cfg[key], merge=False, force_add=False) logger.debug(f"Activated key {key} from pipeline configuration.") # yield to do training etc. yield # restore old configuration for key in self.pipeline_keys: OmegaConf.update(current_cfg, key=key, value=old[key], merge=False, force_add=False) logger.debug("Deactivated pipeline configuration.")
[docs] def store(self, task_struct: TaskStruct, as_blueprint: bool = False) -> Path: """ Store this pipeline. Requires a task struct to determine task name. If blueprint is set, this will be stored in the BLUEPRINTS folder instead of PIPELINES. This allows for easier re-usage. :param TaskStruct task_struct: struct of the task this pipeline has or should be applied upon :param bool as_blueprint: if true store as blueprint otherwise as pipeline :return: the path to the stored file """ # stores pipeline_cfg key = "blueprint" if as_blueprint else "pipeline" path = MMLFileManager.instance().construct_saving_path( obj=self.pipeline_cfg, key=key, task_name=task_struct.name ) # do not resolve to not overwrite OmegaConf.save(config=self.pipeline_cfg, f=path, resolve=False) return path
[docs] @classmethod def load(cls, path: Path, pipeline_keys: Optional[List[str]] = None) -> "PipelineCfg": """ Load a stores pipeline configuration (or blueprint) from path. :param Path path: the path to the stored file :param Optional[List[str]] pipeline_keys: which config keys to focus upon :return: the loaded pipeline configuration (restricted to provided pipeline keys) """ loaded = OmegaConf.load(path) logger.debug(f"Loaded pipeline from {path}.") return cls(pipeline_cfg=loaded, restrict_keys=pipeline_keys)
[docs] def clone(self) -> "PipelineCfg": """ Convenience method to deepcopy the configuration. :return: a deepcopy of the configuration """ return deepcopy(self)
# these are the default keys if none are specified, they comprise all relevant aspects for reproducing a model training PIPELINE_CONFIG_PARTS = [ "arch", "augmentations", "cbs", "loss", "lr_scheduler", "mode", "optimizer", "preprocessing", "sampling", "trainer", "tta", "tune", ]