# Source code for mml.core.scripts.utils

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

import logging
import os
import sys
import time
from enum import Enum
from pathlib import Path
from types import TracebackType
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar

import p_tqdm.p_tqdm as p_tqdm
from dotenv import load_dotenv
from pathos.multiprocessing import ProcessPool
from pathos.threading import ThreadPool

import mml
from mml.core.scripts.decorators import timeout
from mml.core.scripts.exceptions import MMLMisconfigurationException

# generic type variables used in annotations of this module
T = TypeVar("T")
TSingleton = TypeVar("TSingleton", bound="Singleton")

# explicit public API of this module
__all__ = [
    "catch_time",
    "throttle_logging",
    "load_env",
    "multi_threaded_p_tqdm",
    "Singleton",
    "LearningPhase",
    "load_mml_plugins",
    "TAG_SEP",
    "ARG_SEP",
    "ask_confirmation",
]

logger = logging.getLogger(__name__)

# separator characters; their exact parsing semantics live with the consumers (not visible here)
TAG_SEP: str = "+"
ARG_SEP: str = "?"
# provides information on loaded plugins: plugin (entry point) name -> version string, filled by load_mml_plugins
MML_PLUGINS_LOADED: Dict[str, str] = {}


class catch_time:
    """
    Timing utility context manager.

    Usage:

    .. code-block:: python

        with catch_time() as timer:
            # some code
            ...

        # access the time via `timer.pretty_time` afterward, e.g. for logging.

    Once the ``with`` block has exited, the following attributes are available:

    * ``elapsed`` -- total duration in seconds (float, measured with :func:`time.monotonic`)
    * ``hours`` / ``minutes`` -- whole hours and remaining whole minutes (int)
    * ``seconds`` -- remaining seconds (float)
    * ``pretty_time`` -- human-readable summary, e.g. ``"0h 1m  3.25s"``
    """

    def __enter__(self) -> "catch_time":
        # monotonic clock is not affected by system clock adjustments
        self._start = time.monotonic()
        # placeholder until the block exits; the real duration is computed in __exit__
        self.elapsed = 0.0
        return self

    def __exit__(
        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
    ) -> None:
        self.elapsed = time.monotonic() - self._start
        hours, rest = divmod(self.elapsed, 3600)
        minutes, self.seconds = divmod(rest, 60)
        # divmod on floats yields floats - cast so the output reads "1h 2m" instead of "1.0h 2.0m"
        self.hours = int(hours)
        self.minutes = int(minutes)
        self.pretty_time = f"{self.hours}h {self.minutes}m {self.seconds:5.2f}s"
class throttle_logging:
    """
    Logging utility context manager.

    Usage:

    .. code-block:: python

        with throttle_logging(logging.SOME_LEVEL, package):  # package is optional
            # some code that will only propagate logging above (excluding) the specified level
            # (of the package logger if given, otherwise globally)
            ...

        # afterwards logging continues as before.

    The context manager checks if the root logger is in DEBUG mode and prevents throttling in that case.
    On exit, the exact state found on enter is restored (logger level or global disable level), so
    nested usage does not lift an outer throttle prematurely.
    """

    def __init__(self, level: int = logging.INFO, package: Optional[str] = None):
        """
        :param level: highest level that is suppressed; records at this level and below are dropped
        :param package: name of the logger to throttle; if omitted, logging is throttled globally via
            :func:`logging.disable`
        """
        self.level = level
        self.logger = logging.getLogger(package) if package else None
        self.stored_level = self.logger.level if self.logger else None
        # global disable threshold that was active before entering (only used without a package)
        self._stored_disable: int = logging.NOTSET
        # whether __enter__ actually changed anything that __exit__ has to undo
        self._throttled = False

    def __enter__(self) -> "throttle_logging":
        self._throttled = False
        # keep full verbosity while the root logger is in DEBUG mode
        if logging.root.level > logging.DEBUG:
            if self.logger:
                # capture at enter time - the level may have changed since construction
                self.stored_level = self.logger.level
                self.logger.setLevel(self.level + 1)
            else:
                # `root.manager.disable` is where logging.disable() stores its current threshold
                self._stored_disable = logging.root.manager.disable
                logging.disable(self.level)
            self._throttled = True
        return self

    def __exit__(
        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
    ) -> None:
        if not self._throttled:
            return
        if self.logger:
            self.logger.setLevel(self.stored_level)  # type: ignore
        else:
            # restore the previous threshold instead of NOTSET, so an outer throttle stays in effect
            logging.disable(self._stored_disable)
        self._throttled = False
def load_env() -> None:
    """
    Loads the `mml.env` variables. Make sure to have renamed and adapted `example.env` beforehand. If an environment
    variable `MML_ENV_PATH` is given this file is preferred, else the default path inside the `mml` package is used.

    If no ``sys=...`` argument is given on the command line (i.e. the default local system config is used), the
    minimal required variables are validated afterwards: ``MML_DATA_PATH`` and ``MML_RESULTS_PATH`` must be set to
    existing paths and ``MML_LOCAL_WORKERS`` must be parsable as an integer.

    :raises MMLMisconfigurationException: if the env file does not exist or a required variable is unset or invalid
    :return: None
    """
    env_path_override = os.getenv("MML_ENV_PATH", None)
    if env_path_override:
        dotenv_path = Path(env_path_override)
        logger.debug(f"MML_ENV_PATH provided, will try to load env variables from {dotenv_path}.")
    else:
        dotenv_path = Path(mml.__file__).parent / "mml.env"
        logger.debug(f"No MML_ENV_PATH provided, will try to load env variables from default path ({dotenv_path}).")
    if not dotenv_path.exists():
        raise MMLMisconfigurationException(
            f".env file not found at {dotenv_path}! Please follow the documentation instructions to setup MML."
        )
    load_dotenv(dotenv_path=dotenv_path)
    # as long as the default system (=local) is used, check for the minimum env variables
    if not any(cli_arg.startswith("sys=") for cli_arg in sys.argv):
        for path_var in ("MML_DATA_PATH", "MML_RESULTS_PATH"):
            path_value = os.getenv(path_var)
            # an unset (or empty) variable must be reported as misconfiguration, not crash in Path(None)
            if not path_value or not Path(path_value).exists():
                raise MMLMisconfigurationException(f"Invalid {path_var}, have you modified the mml.env entry?")
        try:
            _ = int(os.getenv("MML_LOCAL_WORKERS"))
        except (TypeError, ValueError) as err:
            # TypeError covers an unset variable (int(None)), ValueError a non-numeric value
            raise MMLMisconfigurationException(
                "Invalid MML_LOCAL_WORKERS, have you modified the mml.env entry?"
            ) from err
class multi_threaded_p_tqdm:
    """
    Switches the internally used pool type of p_tqdm package from ProcessPool to ThreadPool.

    Usage:

    .. code-block:: python

        with multi_threaded_p_tqdm():
            # p_tqdm calls in here use a ThreadPool
            ...

    The pool type that was active when entering is restored on exit, so nested usage does not
    switch back to processes too early.
    """

    def __enter__(self) -> None:
        # remember the currently active pool type instead of assuming ProcessPool on exit
        self._previous_pool = p_tqdm.Pool
        p_tqdm.Pool = ThreadPool

    def __exit__(
        self, exc_type: Optional[Type[BaseException]], exc: Optional[BaseException], traceback: Optional[TracebackType]
    ) -> None:
        p_tqdm.Pool = self._previous_pool
class _SingletonMeta(type):
    """
    This is a helper Metaclass to implement the Singleton class.

    The first instantiation of each class using this metaclass is cached (keyed by the class itself). Every later
    instantiation returns that cached object, and the arguments of those later calls are ignored.
    """

    # PEP 526 forbids type variables inside ClassVar, so the cache is typed as a plain class -> instance mapping
    _instances: ClassVar[Dict[type, Any]] = {}

    def __call__(cls: Type[T], *args: Any, **kwargs: Any) -> T:
        if cls not in cls._instances:
            # only the very first call actually constructs (and initializes) the object
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]
class Singleton(metaclass=_SingletonMeta):
    """
    The actual Singleton class to inherit from. Make sure to have Singleton as the leftmost base class.

    The first call of ``Class()`` (or ``Class.instance()``) creates the instance; all further calls return that same
    object and ignore their arguments, until :meth:`clear_instance` is called.
    """

    @classmethod
    def clear_instance(cls) -> None:
        """
        Clears the cached instance of the singleton. Be aware, that this does not affect references to the "old"
        instance, but any further call of "instance" or Class() will create a new instance, that will be returned from
        any further call. Calling this when no instance exists is a no-op.

        :return: None
        """
        cls._instances.pop(cls, None)

    @classmethod
    def instance(cls: Type[TSingleton], *args: Any, **kwargs: Any) -> TSingleton:
        """
        Convenience function that does the same as Class(), but makes the singleton property more readable in code.

        :param args: any init args, be aware that these are ignored if there already exists an instance
        :param kwargs: any init kwargs, be aware that these are ignored if there already exists an instance
        :return: either a new instance (first call) or a reference to the already existing instance
        """
        # call the class itself: `cls.__call__` would pick up an instance-level `__call__` defined by a subclass
        # (if any) instead of the metaclass hook that implements the caching
        return cls(*args, **kwargs)

    @classmethod
    def exists(cls) -> bool:
        """
        Checks whether an instance of this singleton class is currently cached.

        :return: True if an instance has been created and not cleared since
        """
        return cls in cls._instances
class StrEnum(str, Enum):
    """
    Type of any enumerator with allowed comparison to string invariant to cases.

    Adopted from :class:`~pytorch_lightning.utilities.enums.LightningEnum`.
    """

    @classmethod
    def from_str(cls, value: str) -> Optional["StrEnum"]:
        """
        Looks up a member by case-insensitive match against member values or member names.

        :param value: string to match
        :return: the matching member
        :raises ValueError: if no member matches
        """
        wanted = value.lower()
        for enum_key, enum_val in cls.__members__.items():
            if enum_val.lower() == wanted or enum_key.lower() == wanted:
                return cls[enum_key]
        raise ValueError(f"No match found for value {value} in enum {cls.__name__}")

    def __str__(self) -> str:
        return self.value.lower()

    def __eq__(self, other: object) -> bool:
        # values of other enums need not be strings (e.g. IntEnum), so stringify before lowering
        other = str(other.value) if isinstance(other, Enum) else str(other)
        return self.value.lower() == other.lower()

    def __ne__(self, other: object) -> bool:
        # must be explicit: otherwise `str.__ne__` wins in the MRO and compares case-sensitively
        return not self.__eq__(other)

    def __hash__(self) -> int:
        # re-enable hashtable so it can be used as a dict key or in a set
        return hash(self.value.lower())

    @classmethod
    def list(cls) -> List[str]:
        """
        Lists all members of a StrEnum class.

        :return: the (unmodified) values of all members in definition order
        """
        return list(map(lambda c: c.value, cls))
class LearningPhase(StrEnum):
    """
    Enumeration of the learning phases ``train``, ``val`` and ``test``.
    """

    TRAIN = "train"
    VAL = "val"
    TEST = "test"

    @staticmethod
    def all_phases() -> List["LearningPhase"]:
        """
        Collects every phase member.

        :return: a new list holding TRAIN, VAL and TEST in this order
        """
        return [phase for phase in LearningPhase]
def load_mml_plugins() -> None:
    """
    This function allows to load mml plugins. These are other installed packages that provide a 'mml.plugins' entry
    point. See
    https://packaging.python.org/en/latest/guides/creating-and-discovering-plugins/#using-package-metadata
    for details on this mechanism.

    Every discovered plugin is imported via its entry point, and afterwards recorded in the module-level
    ``MML_PLUGINS_LOADED`` mapping as plugin name -> version of the providing distribution.
    """
    if sys.version_info < (3, 10):
        from importlib_metadata import entry_points, version
    else:
        from importlib.metadata import entry_points, version

    def _plugin_version(plugin: Any) -> str:
        # Prefer the distribution that actually provides the entry point; its name may differ from the
        # top-level import name. Fall back to the module-name lookup if no distribution is attached.
        dist = getattr(plugin, "dist", None)
        if dist is not None:
            return dist.version
        return version(plugin.module.split(".")[0])

    # load registered plugins
    discovered_plugins = entry_points(group="mml.plugins")
    if len(discovered_plugins) > 0:
        # usually logger is not yet configured by hydra, so plugins are also stored in global variable later on
        logger.info(f"Discovered plugins: {[p.name for p in discovered_plugins]}!")
    for plugin in discovered_plugins:
        logger.debug(f"Loading plugin {plugin.name}.")
        _ = plugin.load()
        logger.debug(f"Successfully loaded plugin {plugin.name}.")
    # mutating the module-level dict in place does not require a `global` statement
    MML_PLUGINS_LOADED.update({p.name: _plugin_version(p) for p in discovered_plugins})
@timeout(seconds=60)
def ask_confirmation(message: str = "") -> bool:
    """
    Asks the user on the command line to confirm a message.

    The message is printed, then one line of input is read after a ``>>`` prompt. Only the answer ``y``
    (case-insensitive, surrounding whitespace ignored) counts as confirmation. The call is wrapped by
    ``timeout(seconds=60)``.

    :param message: The message to be confirmed
    :return: bool indicating whether the message has been confirmed
    :rtype: bool
    """
    print(message)
    return input(">>").lower().strip() == "y"