# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
import functools
import logging
import warnings
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple, Union
import orjson
from omegaconf import DictConfig
from mml.core.data_loading.file_manager import MMLFileManager
from mml.core.data_loading.task_attributes import Keyword, Modality, RGBInfo, Sizes, TaskType
from mml.core.data_preparation.task_creator import TaskCreator
from mml.core.scripts.decorators import deprecated
from mml.core.scripts.exceptions import TaskNotFoundError
from mml.core.scripts.model_storage import ModelStorage
from mml.core.scripts.utils import TAG_SEP, catch_time
# module-level logger, named after this module
logger = logging.getLogger(__name__)
[docs]
class TaskStruct:
    """
    Object to handle tasks on a meta level in the framework. Contains basic information on location of data and
    links to intermediate and final results. Will be instantiated by the TaskFactory. During runtime results
    corresponding to the dataset will also be stored within the object (such as trained models, calculated FIM and
    performance).
    """

    def __init__(
        self,
        name: str,
        task_type: TaskType,
        means: RGBInfo,
        stds: RGBInfo,
        sizes: Sizes,
        class_occ: Dict[str, int],
        keywords: List[Keyword],
        idx_to_class: Dict[int, str],
        modalities: Dict[Modality, str],
        relative_root: str,
        preprocessed: str,
    ):
        """
        Store the permanent meta information of a task and initialise the (empty) non-permanent attributes.

        :param name: task name, used later on to identify the task
        :param task_type: type of the task, determines the target modality (see :attr:`target`)
        :param means: channel means (of the train set)
        :param stds: channel stds (of the train set)
        :param sizes: image size information of the task
        :param class_occ: class name -> number of occurrences
        :param keywords: task keywords
        :param idx_to_class: class index -> class name
        :param modalities: modality -> entry (the target modality is expected to be among the keys)
        :param relative_root: task root, relative to MMLFileManager.data_path
        :param preprocessed: preprocessing identifier of this task version
        """
        # this used later on to identify the task
        self.name = name
        # permanent task attributes e.g. holds task type, mean and std of train set, relative root path, etc.
        self.task_type = task_type
        self.means = means
        self.stds = stds
        self.sizes = sizes
        self.class_occ = class_occ
        self.relative_root = Path(relative_root)  # relative to MMLFileManager.data_path
        self.preprocessed = preprocessed
        self.keywords = keywords
        self.idx_to_class = idx_to_class
        self.modalities = modalities
        # the target is derived from task_type (and modalities); compute it once and warn if it is not available
        target = self.target
        if target and target not in self.modalities:
            warnings.warn(f"Corrupted target for task {self.name}: {target}")
        # non-permanent attributes, these correspond to experiment specific settings, e.g. a performance, a model
        # trained for the task, auto-augmentation, checkpoints, FIM, features, heads, etc...
        # be aware that in order to store and load them they should only consist of (stacked) default builtin types!
        # e.g. str, int, dict, list, ...
        self.paths: Dict[str, Path] = {}
        self.models: List[ModelStorage] = []
        logger.debug(f"Created TaskStruct-object for task {self.name}.")

    @staticmethod
    def non_permanent_task_attributes() -> Dict[str, Tuple[Callable, Callable]]:
        """
        Returns a dict of task attributes that are not part of the task meta information but are computed by MML.
        The value within the dict is a tuple of callables representing and instantiating the object (to be compatible
        with yaml safe loading and dumping).

        :return: dict with str keys, which are the names of task struct attributes that might be set during MML runtime
            and tuple vals corresponding to (representer, instantiator)
        """
        attrs = {}

        # paths attr: Path objects are represented as plain strings
        def path_representer(path_attr: Dict[str, Path]) -> Dict[str, str]:
            return {k: str(v) for k, v in path_attr.items()}

        def path_instantiator(path_repr: Dict[str, str]) -> Dict[str, Path]:
            return {k: Path(v) for k, v in path_repr.items()}

        attrs["paths"] = (path_representer, path_instantiator)

        # models attr: each model is represented by the path it has been stored at
        def models_representer(model_attr: List[ModelStorage]) -> List[str]:
            # check BEFORE stringifying: str(None) == "None" would otherwise slip through and be dumped as a path
            if any(m._stored is None for m in model_attr):
                raise RuntimeError(
                    "Models that are attached to TaskStructs can only be dumped after a scheduler step "
                    "if the model has been stored before."
                )
            return [str(m._stored) for m in model_attr]

        def models_instantiator(model_repr: List[str]) -> List[ModelStorage]:
            return [ModelStorage.from_json(path=Path(path)) for path in model_repr]

        attrs["models"] = (models_representer, models_instantiator)
        return attrs

    @functools.cached_property
    def num_samples(self) -> int:
        """
        Number of training (or unlabeled) samples. See also
        `~mml.core.data_loading.task_description.TaskDescription.num_samples`.
        The value is read from the task description once and cached on the instance.

        :raises RuntimeError: if no MMLFileManager has been instantiated
        :rtype: int
        """
        # loading only supported with file manager
        if not MMLFileManager.exists():
            raise RuntimeError("TaskStruct supports num_samples only with initiated MMLFileManager.")
        fm = MMLFileManager.instance()
        return fm.load_task_description(fm.data_path / self.relative_root).num_samples

    @property
    def num_classes(self) -> int:
        """Number of distinct class names in ``idx_to_class`` (several indices may share one class name)."""
        return len(set(self.idx_to_class.values()))

    @property
    def target(self) -> Optional[Modality]:
        """
        Modality that serves as prediction target, derived from the task type.

        :return: the target modality, or None for ``TaskType.NO_TASK``
        :raises RuntimeError: for task types without a known target
        """
        if self.task_type == TaskType.CLASSIFICATION:
            return Modality.CLASS
        elif self.task_type == TaskType.SEMANTIC_SEGMENTATION:
            return Modality.MASK
        elif self.task_type == TaskType.MULTILABEL_CLASSIFICATION:
            # prefer hard multi-labels, fall back to soft labels
            return Modality.CLASSES if Modality.CLASSES in self.modalities else Modality.SOFT_CLASSES
        elif self.task_type == TaskType.REGRESSION:
            return Modality.VALUE
        elif self.task_type == TaskType.NO_TASK:
            return None
        else:
            raise RuntimeError("Unable to determine target!")

    @property
    @deprecated(reason="task_struct.id is deprecated, use task.name instead", version="0.12.0")
    def id(self) -> str:
        """Deprecated alias of :attr:`name`."""
        return self.name

    def __str__(self) -> str:
        infos = [
            f"Task name: {self.name}",
            f"Task type: {self.task_type}",
            f"Num classes: {self.num_classes}",
            f"Means: {self.means}",
            f"Stds: {self.stds}",
            f"Sizes: {self.sizes}",
            f"Class occ: {self.class_occ}",
            f"Preprocessed: {self.preprocessed}",
            f"Task keywords: {[kw.value for kw in self.keywords]}",
        ]
        # append the current values of all non-permanent (runtime) attributes
        for attr in self.non_permanent_task_attributes().keys():
            infos.append(f"{attr}: {getattr(self, attr)}")
        return "\n".join(infos)

    def __repr__(self):
        return f"TaskStruct({self.name})"
[docs]
class TaskStructFactory:
    """
    Manages to load all necessary TaskStructs for an experiment. Stores created objects and aggregates information (like
    sizes) across multiple tasks.
    """

    def __init__(self, cfg: DictConfig, load: bool = False):
        """
        :param cfg: experiment config; ``cfg.preprocessing.id`` selects the preprocessing version of created tasks
        :param load: if true, restores task structs from a previous dump (see :meth:`loading_old_dump`)
        """
        self.cfg = cfg
        self.fm = MMLFileManager.instance()
        # all task structs produced so far
        self.container: List[TaskStruct] = []
        # min/max heights and widths aggregated over all tasks in the container
        self.sizes = Sizes()
        self.reset_sizes()
        # load old factory dump
        if load:
            self.loading_old_dump()

    def reset_sizes(self) -> None:
        """
        Resets the aggregated sizes to start values (large minima, zero maxima), so the next min/max updates in
        :meth:`create_task_struct` start fresh.

        :return: None
        """
        self.sizes.min_height = 100000
        self.sizes.max_height = 0
        self.sizes.min_width = 100000
        self.sizes.max_width = 0

    def set_task_struct_defaults(self, task_struct: TaskStruct):
        """
        Based on reuse configs this sets the defaults within the task struct regarding previous results. Currently
        only supports Path and ModelStorage objects!

        :param TaskStruct task_struct: task_struct of the task that values should be loaded
        :return: None
        """
        if task_struct.name in self.fm.reusables:
            for k, v in self.fm.reusables[task_struct.name].items():
                if k == "models":
                    assert isinstance(v, list)
                    for storage in v:
                        assert isinstance(storage, ModelStorage)
                    task_struct.models = v
                    logger.debug(f"Attached {len(v)} reusable models to {task_struct.name}.")
                else:
                    # any other reusable is expected to be an existing path
                    assert isinstance(v, Path)
                    assert v.exists()
                    task_struct.paths[k] = v
                    logger.debug(f"Set {k} path of {task_struct.name} to {v}.")
        else:
            logger.debug(f"No reusable for task {task_struct.name}")

    def loading_old_dump(self) -> None:
        """
        Restores the container from the dump written by :meth:`dump`: recreates each task struct and re-instantiates
        its non-permanent attributes. Loading is useful if an experiment was aborted and is re-initialized.

        :raises FileNotFoundError: if there is no dump file at ``self.fm.task_dump_path``
        :return: None
        """
        logger.info(f"Loading task dump from {self.fm.task_dump_path}.")
        if not self.fm.task_dump_path.exists():
            raise FileNotFoundError(
                f"Specified exp folder ({self.fm.task_dump_path.parent}) requested for loading "
                f"TaskFactory dump is not existing or has incorrectly saved dump (requires "
                f"{self.fm.task_dump_path.name} file)."
            )
        with open(str(self.fm.task_dump_path), "rb") as f:
            all_tasks_dict = orjson.loads(f.read())
        logger.info(f"Starting loading of {len(all_tasks_dict)} tasks...")
        for name, task_dict in all_tasks_dict.items():
            created = self.create_task_struct(name, return_ref=True)
            for attr, (_, instantiator) in TaskStruct.non_permanent_task_attributes().items():
                if attr in task_dict.keys():
                    setattr(created, attr, instantiator(task_dict[attr]))
        # report sizes
        logger.debug(f"Sizes of factory are: {self.sizes}.")
        logger.info(f"Successfully loaded. Container includes {len(self.container)} task structs.")

    def dump(self, clear_container=False) -> None:
        """
        Stores current tasks and their non-permanent attributes as JSON at ``self.fm.task_dump_path``.

        :param clear_container: if true deletes currently loaded tasks afterwards (and resets the aggregated sizes)
        :return: None
        """
        all_tasks_dict = {}
        for task in self.container:
            task_dict = {}
            for attr, (representer, _) in TaskStruct.non_permanent_task_attributes().items():
                if getattr(task, attr) is not None:
                    task_dict[attr] = representer(getattr(task, attr))
            all_tasks_dict[task.name] = task_dict
        with open(str(self.fm.task_dump_path), "wb") as f:
            f.write(orjson.dumps(all_tasks_dict))
        logger.debug(f"Dumped {len(all_tasks_dict)} tasks @ {self.fm.task_dump_path}.")
        if clear_container:
            self.container = []
            self.reset_sizes()

    def create_task_struct(self, name: str, return_ref=False) -> Union[None, TaskStruct]:
        """
        Creates a task struct object via loading necessary information from the meta info json file and adding
        reusable information (e.g. intermediate results from previous experiments) as adaption of the already
        preprocessed version of the task. Finally, the task struct is added to the internal container.

        :param name: name of the task to be created (may carry a "duplicate" suffix, see :func:`undup_names`)
        :param return_ref: if true returns a reference to the created struct, else returns None
        :return: either the created task struct or None
        :raises TaskNotFoundError: if an untagged task is not present in the task index
        :raises RuntimeError: if auto-creation of a tagged task fails, the task index is inconsistent or the file
            manager returns information of a different task
        """
        if self.check_exists(name=name):
            logger.error(f"Task struct {name} to produce already present in the factory container.")
            if return_ref:
                return self.get_by_name(name=name)
            else:
                return
        # make sure to remove duplicate tag
        undup_name = undup_names([name])[0]
        # next check if this is a base task that has not yet been created
        if (TAG_SEP not in undup_name) and (undup_name not in self.fm.task_index):
            # the task is not a tagged one and the base is not present, raise error
            raise TaskNotFoundError(
                f"Was not able to locate task {undup_name}. You may need to call "
                f"<mml create ...> with your current task setting."
            )
        # tagged tasks may be generated on the fly (raw and/or preprocessed version)
        if TAG_SEP in undup_name:
            self._ensure_tagged_task(undup_name)
        # generate struct from meta info provided by file manager
        def_kwargs = self.fm.get_task_info(task_name=undup_name, preprocess=self.cfg.preprocessing.id)
        # the info was requested for the de-duplicated name, so validate against that one
        if def_kwargs["name"] != undup_name:
            raise RuntimeError(f"Received incorrect task information for task {name} (got {def_kwargs['name']}).")
        # keep the requested (potentially duplicate-suffixed) name, so that get_by_name / check_exists and the dump
        # keys match the request; copy instead of mutating the dict returned by the file manager
        new_task = TaskStruct(**{**def_kwargs, "name": name})
        self.container.append(new_task)
        # apply defaults to task struct
        self.set_task_struct_defaults(new_task)
        # update sizes
        self.sizes.min_height = min(self.sizes.min_height, new_task.sizes.min_height)
        self.sizes.max_height = max(self.sizes.max_height, new_task.sizes.max_height)
        self.sizes.min_width = min(self.sizes.min_width, new_task.sizes.min_width)
        self.sizes.max_width = max(self.sizes.max_width, new_task.sizes.max_width)
        logger.debug(f"New factory sizes are: {self.sizes}")
        if return_ref:
            return new_task

    def _ensure_tagged_task(self, undup_name: str) -> None:
        """
        Auto-creates the raw ("none") and/or the configured preprocessed version of a tagged task if missing and
        registers them in the task index of the file manager.

        :param undup_name: tagged task name without "duplicate" suffix
        :raises RuntimeError: if auto-creation fails or a tagged task exists without raw version
        """
        pp_id = self.cfg.preprocessing.id
        if undup_name in self.fm.task_index and pp_id in self.fm.task_index[undup_name]:
            # requested version already present, nothing to do
            return
        if undup_name not in self.fm.task_index:
            # task not yet present at all in task index of the file manager, try to auto generate base task
            logger.info(f"Task {undup_name} not existent yet. Will try to create.")
            try:
                with catch_time() as timer:
                    path = TaskCreator(dset_path=Path("")).auto_create_tagged(
                        full_alias=undup_name, preprocessing="none"
                    )
                logger.debug(f"Task created successfully within {timer.elapsed:5.2f} seconds.")
            except TaskNotFoundError as err:
                # this attempt creates the raw version, so report preprocessing "none"
                raise RuntimeError(f"Unable to auto_create {undup_name} with pp none.") from err
            # add to task index of file manager
            self.fm.add_to_task_index(path)
        # check for inconsistencies
        if "none" not in self.fm.task_index[undup_name]:
            raise RuntimeError(
                f"MML detected a tagged task ({undup_name}) that exists with some "
                f"preprocessing(s) ({list(self.fm.task_index[undup_name].keys())}), but "
                f"no raw version has been found. This may be either because the raw version "
                f"has been removed or you used a previous version of MML to create this "
                f"tagged task. From MML 0.13.0 on tagged preprocessing will only be created "
                f"with a base tagged task. Consider removing all preprocessed version of "
                f"this task to create from scratch "
                f"({list(self.fm.task_index[undup_name].values())})."
            )
        # next check if we need to create a preprocessed version
        if pp_id not in self.fm.task_index[undup_name]:
            base_task = undup_name[: undup_name.find(TAG_SEP)]
            if pp_id in self.fm.task_index[base_task]:
                # this indicates the case that we can leverage existing preprocessing!
                # preprocessed tagged task not yet present in task index of the file manager
                logger.info(f"Generating description of {undup_name} for preprocessing {pp_id}.")
                try:
                    with catch_time() as timer:
                        path = TaskCreator.auto_create_tagged(full_alias=undup_name, preprocessing=pp_id)
                    logger.debug(f"Task created successfully within {timer.elapsed:5.2f} seconds.")
                except TaskNotFoundError as err:
                    raise RuntimeError(f"Unable to auto_create {undup_name} with pp {pp_id}.") from err
                # add to task index of file manager
                self.fm.add_to_task_index(path)

    def get_by_name(self, name: str) -> TaskStruct:
        """
        Returns the internally stored task_struct corresponding to >name<.

        :param name: task name
        :return: the task_struct with that name
        :raises TaskNotFoundError: if no task struct with that name is in the container
        """
        for task in self.container:
            if task.name == name:
                return task
        msg = f"Was not able to find requested dataset {name} in the container of produced task structs."
        raise TaskNotFoundError(msg)

    def check_exists(self, name: str) -> bool:
        """
        Checks whether a given task is present in the container.

        :param str name: task name
        :return: True iff task is within container
        :rtype: bool
        """
        try:
            self.get_by_name(name=name)
        except TaskNotFoundError:
            return False
        return True
[docs]
def undup_names(moded_names_list):
    """
    Strip the "duplicate" suffix from task names.

    Tasks that are present multiple times carry a ``{TAG_SEP}duplicate`` marker. Everything from the first
    occurrence of that marker onwards is cut off; names without the marker are returned unchanged (as ``str``).

    :param moded_names_list: iterable of task names, potentially including the "duplicate"-suffix
    :return: list of strings, task names without the suffix
    """
    marker = f"{TAG_SEP}duplicate"
    # partition splits at the first occurrence; the head is the full string if the marker is absent
    return [str(raw_name).partition(marker)[0] for raw_name in moded_names_list]