# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
import dataclasses
import logging
import math
import os
import subprocess
import sys
from pathlib import Path
from typing import Dict, List, Optional, Sequence, Set, Union
import numpy as np
import pandas as pd
import scipy.stats
from omegaconf import OmegaConf
from mml.cli import main
from mml.core.data_loading.task_attributes import Keyword, TaskType
from mml.core.data_preparation.registry import get_dset_for_task
from mml.core.scripts.utils import TAG_SEP
from mml.interactive import _check_init
from mml.interactive.loading import default_file_manager
logger = logging.getLogger(__name__)
[docs]
@dataclasses.dataclass
class AllTasksInfos:
"""
A class to store all standard meta information on a set of tasks.
"""
num_classes: Dict[str, int]
num_samples: Dict[str, int]
imbalance_ratios: Dict[str, float]
datasets: Dict[str, str]
keywords: Dict[str, Set[Keyword]]
task_types: Dict[str, TaskType]
domains: Dict[str, Keyword]
dimensions: Dict[str, int]
max_resolution: Dict[str, int]
min_resolution: Dict[str, int]
small_tasks: List[str]
medium_tasks: List[str]
large_tasks: List[str]
def __repr__(self):
return f"AllTasksInfos({len(self.num_classes)} tasks: {list(self.num_classes.keys())})"
def __post_init__(self):
"""
Automatically called after instantiation of an instance. Performs consistency check.
:return:
"""
self.check_consistency()
[docs]
def check_consistency(self):
"""
Performs assertions that all information cover the same set of tasks.
:return:
"""
assert (
self.num_classes.keys()
== self.num_samples.keys()
== self.keywords.keys()
== self.max_resolution.keys()
== self.min_resolution.keys()
)
assert (self.dimensions.keys() == self.num_classes.keys()) or len(self.dimensions) == 0
# assert set(self.small_tasks) != set(self.medium_tasks) != set(self.large_tasks) != set(self.small_tasks)
assert set(self.num_classes) == set(self.small_tasks + self.medium_tasks + self.large_tasks)
[docs]
def store_csv(self, path: Path) -> None:
"""
Reformat meta information and write as a csv file.
:param Path path: path to store csv file
:return: None
"""
all_tasks = sorted(self.num_samples.keys())
task_infos = []
for t in all_tasks:
info_dict = {"name": t}
for attr in [
"task_types",
"num_classes",
"num_samples",
"domains",
"imbalance_ratios",
"datasets",
"keywords",
"dimensions",
"max_resolution",
"min_resolution",
]:
attr_dict = getattr(self, attr)
try:
elem = attr_dict[t]
if attr == "keywords": # Set[Keyword]
elem = [str(_kw) for _kw in sorted(elem)]
elif attr == "task_types" or attr == "domains": # TaskType or Keyword
elem = str(elem)
info_dict[attr] = elem
except KeyError:
# some entries might be missing
pass
task_infos.append(info_dict)
df = pd.DataFrame(task_infos)
df.to_csv(path)
[docs]
@classmethod
def from_csv(cls, path: Path) -> "AllTasksInfos":
"""
Load stored AllTasksInfos from a csv file.
:param path: path to load csv file
:return: AllTasksInfos
"""
df = pd.read_csv(path).set_index("name")
kwargs = {}
for attr in [
"task_types",
"num_classes",
"num_samples",
"domains",
"imbalance_ratios",
"datasets",
"keywords",
"dimensions",
"max_resolution",
"min_resolution",
]:
kwargs[attr] = df[attr].to_dict() if attr in df.columns else {}
if attr == "keywords": # Dict[str, Set[Keyword]]
for task in kwargs[attr]:
if kwargs[attr][task]:
kwargs[attr][task] = set(
[Keyword.from_str(_kw.strip(" \[\]'")) for _kw in kwargs[attr][task].split(",")]
)
elif attr == "task_types": # Dict[str, TaskType]
kwargs[attr] = {task: TaskType(entry) for task, entry in kwargs[attr].items() if isinstance(entry, str)}
elif attr == "domains": # Dict[str, Keyword]
kwargs[attr] = {task: Keyword(entry) for task, entry in kwargs[attr].items() if isinstance(entry, str)}
small_tasks = [t for t, size in kwargs["num_samples"].items() if size < 1000]
medium_tasks = [t for t, size in kwargs["num_samples"].items() if 1000 <= size < 10000]
large_tasks = [t for t, size in kwargs["num_samples"].items() if size >= 10000]
all_infos = cls(small_tasks=small_tasks, medium_tasks=medium_tasks, large_tasks=large_tasks, **kwargs)
return all_infos
[docs]
def get_task_infos(task_list: List[str], dims: Optional[str] = None) -> AllTasksInfos:
"""
Most convenient way to receive a :class:AllTasksInfos instance. Provide a list of aliases and optional a project
name that computed dimensions before.
:param List[str] task_list: list of task names, tasks must be available on the machine (run create before if
not)
:param Optional[str] dims: (optional) project name that computed dimensions with mml dim proj=THIS_ARG
:return: relevant meta information on all tasks combined in one object
:rtype: AllTasksInfos
"""
_check_init()
if dims:
r_conf = OmegaConf.create({"dimension": dims})
try:
from mml_dimensionality.scripts.utils import load_dim
except ImportError:
raise ImportError("Install mml-dimensionality to use the dim functionality of get_task_infos.")
else:
load_dim = None
r_conf = OmegaConf.create({})
# initialize empty dicts
num_classes = {}
num_samples = {}
imbalance_ratios = {}
keywords = {}
domains = {}
dimensions = {}
task_types = {}
max_resolution = {}
min_resolution = {}
domain_list = [
Keyword.DERMATOSCOPY,
Keyword.LARYNGOSCOPY,
Keyword.GASTROSCOPY_COLONOSCOPY,
Keyword.LAPAROSCOPY,
Keyword.NATURAL_OBJECTS,
Keyword.HANDWRITINGS,
Keyword.CATARACT_SURGERY,
Keyword.FUNDUS_PHOTOGRAPHY,
Keyword.MRI_SCAN,
Keyword.X_RAY,
Keyword.CT_SCAN,
Keyword.CLE,
Keyword.CAPSULE_ENDOSCOPY,
Keyword.ULTRASOUND,
]
overarching_dataset = {}
with default_file_manager(reuse_config=r_conf) as fm:
for task in task_list:
meta = fm.get_task_info(task, preprocess="none")
num_classes[task] = len(set(meta["idx_to_class"].values()))
task_desc = fm.load_task_description(fm.data_path / meta["relative_root"])
num_samples[task] = task_desc.num_samples
try:
imbalance_ratios[task] = max(meta["class_occ"].values()) / min(meta["class_occ"].values())
except ZeroDivisionError:
logger.error(f"Division by zero while computing imbalance ratio for task {task}.")
except ValueError:
logger.error(f"Unable to compute imbalance ratio for task {task}. Class occ = {meta['class_occ']}")
keywords[task] = set(meta["keywords"])
task_types[task] = meta["task_type"]
sizes = meta["sizes"]
min_resolution[task] = sizes.min_height * sizes.min_width
max_resolution[task] = sizes.max_height * sizes.max_width
domain_candidates = [d for d in keywords[task] if d in domain_list]
if len(domain_candidates) != 1:
logger.error(f"Error while searching domain for task {task}, have candidates {domain_candidates}")
else:
domains[task] = domain_candidates[0].value
try:
plain_task = task.split(TAG_SEP)[0]
overarching_dataset[task] = get_dset_for_task(task_name=plain_task)
except KeyError:
logger.error(f"No registered dataset for task {task}.")
if dims:
dimensions[task] = load_dim(fm.reusables[task]["dimension"])
small_tasks = [t for t, size in num_samples.items() if size < 1000]
medium_tasks = [t for t, size in num_samples.items() if 1000 <= size < 10000]
large_tasks = [t for t, size in num_samples.items() if size >= 10000]
all_infos = AllTasksInfos(
num_classes=num_classes,
num_samples=num_samples,
imbalance_ratios=imbalance_ratios,
keywords=keywords,
domains=domains,
dimensions=dimensions,
max_resolution=max_resolution,
min_resolution=min_resolution,
small_tasks=small_tasks,
medium_tasks=medium_tasks,
large_tasks=large_tasks,
datasets=overarching_dataset,
task_types=task_types,
)
return all_infos
# Helper functions for rendering and running mml calls in different scenarios
[docs]
class JobPrefixRequirements:
"""The job prefix requirements to a job. Basically resolves how to invoke mml on the system."""
[docs]
def get_prefix(self) -> str: ...
[docs]
class DefaultRequirements(JobPrefixRequirements):
"""The default how to call MML from e.g. a local machine (assuming it to be installed and the environment to
be loaded."""
[docs]
def get_prefix(self) -> str:
return "mml"
[docs]
@dataclasses.dataclass
class MMLJobDescription:
"""
Combined description of an MML call. Includes prefix requirements, config options and a multirun flag for hpo.
"""
prefix_req: JobPrefixRequirements
mode: str
config_options: Dict[str, Union[str, float, List[Union[str, int, float]], int]]
multirun: bool = False
[docs]
def render(self) -> str:
"""
Actually renders the job description.
:return: A string that might be pasted into a terminal to start the job described.
"""
parts = [self.prefix_req.get_prefix(), self.mode]
for key, option in self.config_options.items():
if isinstance(option, str) and " " in option:
raise ValueError("Found whitespace in JobDescription!")
# no check for interrupting str in list
if isinstance(option, list):
option = str(option).replace(" ", "")
if key == "mode":
raise ValueError(
"Providing mode inside config_options is not supported anymore, provide mode "
"directly to MMLJobDescription."
)
parts.append(f"{key}={option}")
if self.multirun:
parts.append("--multirun")
return " ".join(parts)
[docs]
def run(self, runner: "JobRunner") -> Optional[float]:
"""
Runs the job with the given runner.
:param JobRunner runner: the runner to run the job.
:return: Potentially a float that represents the return value of the specified experiment (not guaranteed)
"""
return runner.run(job=self)
[docs]
class JobRunner:
"""The runner that invokes the rendered MML call."""
[docs]
def run(self, job: MMLJobDescription): ...
[docs]
class EmbeddedJobRunner(JobRunner):
"""
The embedded runner allows to start mml directly from within the same python interpreter, hence any previous
variables, imports, etc. are available during runtime. This also allows to receive the return value of MML.
"""
[docs]
def run(self, job: MMLJobDescription):
# override sys.argv
sys.argv = job.render().split(" ")
return main()
[docs]
class SubprocessJobRunner(JobRunner):
"""
The subprocess runner only inherits the virtual environment but starts a new process including a new interpreter.
Any variables in the current interpreter will not be available during this run. It does not receive any return
values of an experiment.
"""
[docs]
def run(self, job: MMLJobDescription):
subprocess.run(job.render().split(" "))
# convenience function for producing long outputs
[docs]
def write_out_commands(
cmd_list: List[MMLJobDescription],
name: str = "output",
seperator: Optional[str] = "sleep 2\n",
max_cmds: Optional[int] = None,
) -> None:
"""
Writes a list of :class:MMLJobDescription into a file that may be called by a shell afterward. This is particularly
useful if the commands should be transferred to a different host via ssh, e.g. with::
ssh user@host 'bash -s' < /path/to/output.txt
:param List[MMLJobDescription] cmd_list: list of commands
:param Optional[str] name: a file name to relate cmds to a common project or experiment, defaults to 'output'
:param Optional[str] seperator: (optional) a line seperator, useful if e.g. sleep X should delay cmd submission to
a cluster
:param Optional[int] max_cmds: (optional) max number of cmds per file, will split into consecutive files if more
cmds are present
:return:
"""
if seperator is None:
seperator = ""
else:
if len(seperator) > 0 and seperator[-1] != "\n":
raise ValueError(
"Seperator does not end line, this may cause interference with cmds and should be avoided!"
)
if max_cmds is None:
max_cmds = len(cmd_list) + 1
num_splits = math.ceil(len(cmd_list) / max_cmds)
for split_idx in range(num_splits):
file_name = name
if num_splits > 1:
file_name += f"_{split_idx}"
file_name += ".txt"
out = ""
for cmd in cmd_list[split_idx * max_cmds : (split_idx + 1) * max_cmds]:
out = out + cmd.render() + "\n" + seperator
with open(Path(os.path.abspath("")) / file_name, "w") as file:
file.write(out)
print(f"Stored {len(cmd_list[split_idx * max_cmds : (split_idx + 1) * max_cmds])} commands at {file_name}.")