Source code for mml.core.scripts.schedulers.upgrade_scheduler

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

import logging
import tempfile
import warnings
from pathlib import Path
from typing import List, Tuple

import omegaconf.listconfig
import orjson
from omegaconf import DictConfig, OmegaConf
from rich.progress import track

import mml
from mml.core.data_loading.file_manager import TASK_PREFIX, MMLFileManager
from mml.core.data_loading.task_description import ALL_TASK_DESCRIPTION_KEYS
from mml.core.scripts.exceptions import MMLMisconfigurationException
from mml.core.scripts.schedulers.base_scheduler import AbstractBaseScheduler
from mml.core.scripts.utils import ask_confirmation

logger = logging.getLogger(__name__)


class UpgradeScheduler(AbstractBaseScheduler):
    """
    AbstractBaseScheduler implementation for the Dataset and MML Upgrade process. Includes the following subroutines:

    - upgrade
    - downgrade

    For upgrade we always assume to upgrade to the currently installed version of MML. For downgrade, we assume that
    previously MML has been upgraded to the currently installed version and is now downgraded to the specified
    version in cfg.mode.version.
    """

    def __init__(self, cfg: DictConfig):
        """
        Validate the up-/downgrade configuration and set up the scheduler.

        :param cfg: the run configuration; ``cfg.mode.version`` must hold the source (upgrade) or target
            (downgrade) version, either as a list of three ints (``[x, y, z]``) or as a string (``"x.y.z"``),
            and ``cfg.mode.subroutines`` must not contain both ``upgrade`` and ``downgrade``
        :raises MMLMisconfigurationException: if the version is missing, has an unsupported type, is a list with
            non-int entries, does not consist of exactly three parts, or if both subroutines are requested
        :raises ValueError: if a version given as string contains a non-integer component
        """
        # make sure to create MMLFileManager beforehand to avoid RunTimeErrors during super.__init__
        # (the file manager only receives throw-away temporary directories here)
        MMLFileManager(
            proj_path=Path(tempfile.mkdtemp()),
            data_path=Path(tempfile.mkdtemp()),
            log_path=Path(tempfile.mkdtemp()),
            reuse_cfg=None,  # nothing to reuse here
            remove_cfg=None,  # nothing to remove either
        )
        # assert correct configuration
        if OmegaConf.is_missing(cfg.mode, "version"):
            raise MMLMisconfigurationException(
                "You are up/downgrading the mml setup without specifying a version! "
                "In case of upgrading please provide a source version, in case of "
                "downgrading please provide a target version. If in doubt, please "
                "read the documentation!"
            )
        raw_version = cfg.mode.version
        # tuple variant for better compatibility
        # NOTE: a list inside an OmegaConf config is a ListConfig, which is NOT a subclass of list, so it has to be
        # checked for explicitly - plain list / tuple are accepted as well for robustness
        if isinstance(raw_version, (omegaconf.listconfig.ListConfig, list, tuple)):
            if not all(isinstance(elem, int) for elem in raw_version):
                raise MMLMisconfigurationException("Specify source/target version as list of three int.")
            self.version = tuple(raw_version)
        elif isinstance(raw_version, str):
            self.version = tuple(int(part) for part in raw_version.split("."))
        else:
            # e.g. None or a float such as an unquoted 0.12
            raise MMLMisconfigurationException("Provide mode.version either as list (aka [x,y,z]) or as str (x.y.z).")
        if len(self.version) != 3:
            raise MMLMisconfigurationException("Provide mode.version as major - minor - patch (list or str).")
        subroutines = list(cfg.mode.subroutines)
        if "upgrade" in subroutines and "downgrade" in subroutines:
            raise MMLMisconfigurationException("Upgrade mode may either be used to upgrade or to downgrade.")
        self.upgrading = "upgrade" in subroutines
        # initialize
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            super().__init__(cfg=cfg, available_subroutines=["upgrade", "downgrade"])
        # since self.fm does not help with correct paths, we store the correct ones as well for the use in the scheduler
        self.data_path = Path(self.cfg["data_dir"])

    def prepare_exp(self) -> None:
        """
        Prepare experiment expects tasks to be present and loads these into task factory container. Here this
        should be avoided (including :meth:`after_preparation_hook`).
        """
        logger.debug("Skipping experiment setup!")

    def create_routine(self) -> None:
        """
        Collect all patches between the given version and the installed MML version and queue them as commands.
        Patches are applied in ascending version order when upgrading and in descending order when downgrading.
        If at least one patch is required the user has to confirm before the routine is accepted.

        :raises TimeoutError: if the user does not answer the confirmation prompt
        :raises RuntimeError: if the user declines the confirmation
        :return: None
        """
        # determine necessary patches from dict of all available patches
        # pattern: Key=Version of change - Value: function to up AND downgrade
        patches = {(0, 12, 0): self.upgrade_0_12}
        # sort and filter according to subroutine
        patch_ids: List[Tuple[int, int, int]] = sorted(
            (patch_id for patch_id in patches if self.version < patch_id <= mml.VERSION),
            reverse=not self.upgrading,
        )
        # -- add commands
        for patch_id in patch_ids:
            self.commands.append(patches[patch_id])
            self.params.append([])
        if not patch_ids:
            logger.info("No patches necessary!")
            return
        # ensure user is aware of implications
        msg = (
            f"You are about to {'upgrade' if self.upgrading else 'downgrade'} your MML environment, "
            f"from version {self.version if self.upgrading else mml.VERSION} to version "
            f"{mml.VERSION if self.upgrading else self.version}. Although the effects should be revertible it "
            f"is recommended to create a backup of your data and/or results! Do you want to continue? "
            f'Please type "y"'
        )
        try:
            confirmed = ask_confirmation(self.highlight_text(msg))
        except TimeoutError:
            logger.error("No input provided for necessary response, will kill this run!")
            raise
        if not confirmed:
            raise RuntimeError('Stopped MML up/downgrade scheduler. To up/downgrade rerun and answer "y".')

    def upgrade_0_12(self) -> None:
        """
        This performs the necessary updates to results and data from 0.11 to 0.12 version of mml-core (or reverts
        them when downgrading). Iterates over all installed task descriptions and renames the keys
        tags <-> keywords, train_tuples <-> train_samples, alias <-> name, unlabeled_tuples <-> unlabeled_samples
        and test_tuples <-> test_samples. Keys that were dropped in 0.12 (task_id, orig_performance,
        top_performance) are re-created when downgrading.

        Files are rewritten in place one after another, so an error raised for one description leaves the
        previously processed descriptions already converted.

        :raises RuntimeError: if an expected key is missing in a task description
        :return: None
        """
        logger.info("Now rolling patch 0.12")
        all_task_descriptions = self._get_all_task_descriptions()
        logger.info("Found %d task descriptions to update.", len(all_task_descriptions))
        # renamed keys (format is new : old)
        replacements = {
            "keywords": "tags",
            "train_samples": "train_tuples",
            "name": "alias",
            "unlabeled_samples": "unlabeled_tuples",
            "test_samples": "test_tuples",
        }
        # processing each task description
        for description_path in track(all_task_descriptions, description="Updating task descriptions..."):
            # load
            data_dict = orjson.loads(description_path.read_bytes())
            # ensuring correct order for fast loading of header
            new_data_dict = {}
            # some keys got deprecated, will be added when downgrading (this causes information loss during upgrading)
            if not self.upgrading:
                new_data_dict["task_id"] = description_path.parent.name + "%" + description_path.stem
                new_data_dict["orig_performance"] = ""
                new_data_dict["top_performance"] = ""
            # these are all remaining entries (sorted via the new definition)
            for key in ALL_TASK_DESCRIPTION_KEYS:
                old_key = replacements.get(key, key)  # identical to key unless it has been renamed
                current, target = (old_key, key) if self.upgrading else (key, old_key)
                if current not in data_dict:
                    raise RuntimeError(f"Did not find {current} in {description_path}.")
                new_data_dict[target] = data_dict[current]
            # serialize BEFORE touching the file, so a serialization error can not leave a truncated description
            payload = orjson.dumps(new_data_dict)
            # write
            description_path.write_bytes(payload)
            logger.debug("Updated task description %s", description_path)
        logger.info("Done rolling patch 0.12.")

    def _get_all_task_descriptions(self) -> List[Path]:
        """
        Helper function to retrieve all task descriptions of the installation. Searches every dataset folder
        inside ``RAW`` as well as every dataset folder of each preprocessing folder inside ``PREPROCESSED``.

        :return: List of paths to all .json files defining task descriptions.
        """
        # wrapping each character of the prefix in brackets makes glob match it literally
        pattern = "".join(["[" + c + "]" for c in TASK_PREFIX]) + "*.json"
        # gather all TASK descriptions
        all_task_descriptions = []
        for dataset in (self.data_path / "RAW").iterdir():
            if dataset.is_dir():
                all_task_descriptions.extend(dataset.glob(pattern))
        # now preprocessed folder
        for preprocess in (self.data_path / "PREPROCESSED").iterdir():
            if not preprocess.is_dir():
                continue
            for dataset in preprocess.iterdir():
                if dataset.is_dir():
                    all_task_descriptions.extend(dataset.glob(pattern))
        return all_task_descriptions