# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
import logging
from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional
from mml.core.data_loading.task_attributes import Keyword, License, Modality, ModalityEntry, RGBInfo, Sizes, TaskType
logger = logging.getLogger(__name__)

# header entries that must be present to construct a TaskStruct
STRUCT_REQ_HEADER_KEYS = [
    "name",
    "task_type",
    "keywords",
    "modalities",
    "means",
    "stds",
    "sizes",
    "idx_to_class",
    "class_occ",
]

# the full header: the required entries followed by the descriptive meta information
ALL_HEADER_KEYS = [
    *STRUCT_REQ_HEADER_KEYS,
    "description",
    "creation_protocol",
    "reference",
    "url",
    "download",
    "license",
    "release",
]

# every entry of a task description: the full header followed by the sample related entries
ALL_TASK_DESCRIPTION_KEYS = [
    *ALL_HEADER_KEYS,
    "unlabeled_samples",
    "train_folds",
    "train_samples",
    "test_samples",
]
# type alias: a single sample is described by a mapping from Modality to ModalityEntry
SampleDescription = Dict[Modality, ModalityEntry]
@dataclass
class TaskDescription:
    """
    A task description holding the meta information on task background as well as the actual links to samples.

    Use :meth:`to_json` to obtain a json compatible dictionary and :meth:`from_json` to restore a
    TaskDescription from such a dictionary.
    """

    # provided
    name: Optional[str] = None  # renamed from alias
    description: str = ""
    creation_protocol: str = ""
    reference: str = ""
    url: str = ""
    download: str = ""
    license: License = License.UNKNOWN
    release: str = ""
    task_type: TaskType = TaskType.UNKNOWN
    keywords: List[Keyword] = field(default_factory=list)  # renamed from tags
    # inferred
    means: RGBInfo = field(default_factory=RGBInfo)
    stds: RGBInfo = field(default_factory=RGBInfo)
    sizes: Sizes = field(default_factory=Sizes)
    modalities: Dict[Modality, str] = field(default_factory=dict)
    idx_to_class: Dict[int, str] = field(default_factory=dict)
    class_occ: Dict[str, int] = field(default_factory=dict)
    # created
    unlabeled_samples: Dict[str, SampleDescription] = field(default_factory=dict)  # renamed
    train_folds: List[List[str]] = field(default_factory=list)
    train_samples: Dict[str, SampleDescription] = field(default_factory=dict)  # renamed
    test_samples: Dict[str, SampleDescription] = field(default_factory=dict)  # renamed

    def to_json(self) -> Dict[str, Any]:
        """
        Helper to transform a TaskDescription into a json compatible dict.

        :return: (dict) A dictionary without any custom classes as values to be saved in json format. It holds
            one entry for each key of ALL_TASK_DESCRIPTION_KEYS.
        """
        json_data = {}
        for key in ALL_TASK_DESCRIPTION_KEYS:
            value = getattr(self, key)
            # transform plain StrEnum to str
            if key in ("task_type", "license"):
                value = value.value
            # transform listed StrEnum to str
            elif key == "keywords":
                value = [elem.value for elem in value]
            # transform dict StrEnum to str
            elif key == "modalities":
                value = {k.value: v for k, v in value.items()}
            # transform RGBInfo and Sizes
            elif key in ("means", "stds", "sizes"):
                value = value.to_list()
            # transform nested StrEnum to str (the modality keys of each sample)
            elif key in ("unlabeled_samples", "train_samples", "test_samples"):
                value = {
                    top_k: {low_k.value: low_v for low_k, low_v in top_v.items()} for top_k, top_v in value.items()
                }
            # json only allows str keys
            elif key == "idx_to_class":
                value = {str(k): v for k, v in value.items()}
            # all remaining entries are stored as they are
            json_data[key] = value
        return json_data

    @classmethod
    def from_json(cls, data_dict: Dict[str, Any]) -> "TaskDescription":
        """
        Counterpart for the to_json function: Replaces enum values with their entities and creates a TaskDescription.

        :param Dict[str, Any] data_dict: a dictionary without any custom classes as values, as produced by to_json.
            Keys may be omitted, in which case the default of the corresponding field is used.
        :return: a TaskDescription with entries as encoded in the data_dict
        :raises KeyError: if data_dict contains a key that is not listed in ALL_TASK_DESCRIPTION_KEYS
        """
        # work on a copy, so the created TaskDescription shares no mutable state with the given dict
        data_dict = deepcopy(data_dict)
        # lookup tables are built once instead of once per processed key
        plain_enums = {"task_type": TaskType, "license": License}
        stat_classes = {"means": RGBInfo, "stds": RGBInfo, "sizes": Sizes}
        cls_kwargs = {}
        for key in data_dict:
            if key not in ALL_TASK_DESCRIPTION_KEYS:
                raise KeyError(f"Key {key} not part of a TaskDescription!")
            value = data_dict[key]
            # transform str to plain StrEnum
            if key in plain_enums:
                value = plain_enums[key](value)
            # transform listed str to StrEnum
            elif key == "keywords":
                value = [Keyword(elem) for elem in value]
            # transform str keys to StrEnum
            elif key == "modalities":
                value = {Modality(k): v for k, v in value.items()}
            # transform lists to RGBInfo and Sizes
            elif key in stat_classes:
                value = stat_classes[key](*value)
            # transform nested str keys to StrEnum (the modality keys of each sample)
            elif key in ("unlabeled_samples", "train_samples", "test_samples"):
                value = {
                    top_k: {Modality(low_k): low_v for low_k, low_v in top_v.items()} for top_k, top_v in value.items()
                }
            # revert the str keys enforced by json
            elif key == "idx_to_class":
                value = {int(k): v for k, v in value.items()}
            cls_kwargs[key] = value
        return cls(**cls_kwargs)

    @property
    def num_samples(self) -> int:
        """
        The number of training (or unlabeled) samples.

        :return: the summed length of all train folds if that is positive, otherwise the number of unlabeled
            samples, otherwise 0 (in which case an error is logged)
        """
        # training samples are counted through the fold lists
        train_cases = sum(map(len, self.train_folds))
        if train_cases > 0:
            return train_cases
        # we might have an unlabeled task if no train folds are defined
        if len(self.unlabeled_samples) > 0:
            return len(self.unlabeled_samples)
        # we return zero if neither are found, but log error
        logger.error(f"Asked for number of samples of task {self.name}, but did not find any samples")
        return 0