Source code for mml.interactive.loading

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

import contextlib
import os
from pathlib import Path
from typing import Dict, Generator, Iterable, List, Optional, Sequence, Union

import omegaconf
from omegaconf import DictConfig

from mml.core.data_loading.file_manager import MMLFileManager, ReuseConfig
from mml.core.data_loading.task_struct import TaskStruct, TaskStructFactory
from mml.core.scripts.model_storage import ModelStorage
from mml.interactive import _check_init


def load_project_models(project: str) -> Dict[str, List[ModelStorage]]:
    """
    Loading utility to get all models of a given project.

    A temporary :class:`MMLFileManager` is created on the project folder below ``MML_RESULTS_PATH`` with a reuse
    config pointing at the project's own models. The manager instance is cleared again before returning, also
    if collecting the reusables fails.

    :param str project: name of the project, what has been inserted with 'proj=...'
    :return: dict with task name keys and a list of all corresponding ModelStorages
    """
    _check_init()
    proj_path = Path(f"{os.getenv('MML_RESULTS_PATH')}/{project}")
    tmp_log_path = proj_path / "tmp_log"
    tmp_log_path.mkdir(exist_ok=True)
    r_conf = ReuseConfig(models=f"{project}")
    # same guard as in default_file_manager: never leave a file manager instance behind when something raises
    try:
        fm = MMLFileManager(
            data_path=Path(f"{os.getenv('MML_DATA_PATH')}"),
            proj_path=proj_path,
            log_path=tmp_log_path,
            reuse_cfg=r_conf,
        )
        all_reusables = fm.reusables
    finally:
        MMLFileManager.clear_instance()
    return {task_name: task_reusables["models"] for task_name, task_reusables in all_reusables.items()}
def merge_project_models(
    project_models_list: Iterable[Dict[str, List["ModelStorage"]]],
) -> Dict[str, List["ModelStorage"]]:
    """
    Merges models loaded from multiple projects.

    The input dicts and their lists are left untouched, every list in the result is a new list object.

    :param project_models_list: iterable of dicts, as returned by multiple calls of :func:`load_project_models`
    :return: merged dict (task name -> list of models), as if all models were trained in one single project
    """
    out: Dict[str, List["ModelStorage"]] = {}
    for project_models in project_models_list:
        for task, model_list in project_models.items():
            # always extend a fresh list - storing the caller's list itself would let a later extend
            # silently modify the input of the first project that contained this task
            out.setdefault(task, []).extend(model_list)
    return out
@contextlib.contextmanager
def default_file_manager(
    reuse_config: Optional[Union[DictConfig, ReuseConfig]] = None,
) -> Generator[MMLFileManager, None, None]:
    """
    Context manager providing a MMLFileManager on the ``default`` project.

    The ``default`` project folder below ``MML_RESULTS_PATH`` and its ``tmp_log`` sub folder are created if they
    are missing. When the with block is left the file manager instance is cleared again. Usage:

    .. code-block:: python

        with default_file_manager() as fm:
            fm.do_something (e.g. extract information)
        ... continue code with extracted information (without fm)

    :param reuse_config: optional reuse config handed to the file manager, falls back to an empty ReuseConfig
    :return: generator yielding the file manager instance
    """
    _check_init()
    results_root = Path(f"{os.getenv('MML_RESULTS_PATH')}/default")
    results_root.mkdir(exist_ok=True)
    log_dir = results_root / "tmp_log"
    log_dir.mkdir(exist_ok=True)
    effective_reuse = ReuseConfig() if reuse_config is None else reuse_config
    try:
        manager = MMLFileManager(
            data_path=Path(f"{os.getenv('MML_DATA_PATH')}"),
            proj_path=results_root,
            log_path=log_dir,
            reuse_cfg=effective_reuse,
        )
        yield manager
    finally:
        # runs on normal exit as well as on errors inside the with block or during construction
        MMLFileManager.clear_instance()
def get_task_structs(tasks: Union[str, Sequence[str]], preprocessing: str = "default") -> List[TaskStruct]:
    """
    Create task structs on the fly.

    :param tasks: a single task name or a sequence of task names
    :param str preprocessing: the preprocessing id of the task(s) (default: 'default')
    :return: list with one task struct per requested task, in the order given (a single name gives a list of one)
    :rtype: List[TaskStruct]
    """
    _check_init()
    # a plain string is treated as exactly one task name
    task_names = [tasks] if isinstance(tasks, str) else tasks
    cfg = omegaconf.OmegaConf.create({"preprocessing": {"id": preprocessing}})
    # the factory is only used while a file manager exists
    with default_file_manager():
        factory = TaskStructFactory(cfg=cfg)
        return [factory.create_task_struct(name=task_name, return_ref=True) for task_name in task_names]