Source code for mml.core.data_preparation.data_archive

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

import hashlib
import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Optional

from mml.core.scripts import utils as core_utils

logger = logging.getLogger(__name__)


class DataKind(core_utils.StrEnum):
    """
    Categories that a piece of data may be assigned to.

    The kind is intended for sorting data into distinct top level folders. Whenever several kinds
    are mixed together, MIXED is the default to fall back to. Nothing enforces the usage of these
    kinds, they merely may help to keep a data storage structured.
    """

    # default: no single kind applies
    MIXED = "mixed_data"
    UNLABELED_DATA = "unlabeled_data"
    # data and labels are kept apart for both the training and the testing portion
    TRAINING_DATA = "training_data"
    TRAINING_LABELS = "training_labels"
    TESTING_DATA = "testing_data"
    TESTING_LABELS = "testing_labels"
@dataclass
class DataArchive:
    """A simple dataclass holding information about a data archive (e.g. a zipfile)."""

    path: Path  # path to the archive
    kind: DataKind = DataKind.MIXED  # datakind
    md5sum: Optional[str] = None  # expected md5 hex digest of the archive file (None: no check is done)
    password: Optional[str] = None  # is there a password encryption?
    keep_top_level: bool = False  # should an additional layer be created during extraction?

    def check_hash(self) -> None:
        """
        Checks if the optional md5sum of the DataArchive matches the actual file's md5sum.

        Nothing is done if no md5sum was provided. Otherwise the file at ``path`` is read in blocks,
        so it never has to fit into memory as a whole. The comparison ignores the letter case of
        the expected md5sum.

        :raises RuntimeError: if the md5sum computed from the file differs from the expected one
        :return: None
        """
        if self.md5sum is None:
            return
        block_size = 65536
        hasher = hashlib.md5()
        with open(self.path, "rb") as file:
            # two-argument iter: call file.read until it returns the empty bytes sentinel (EOF)
            for buf in iter(lambda: file.read(block_size), b""):
                hasher.update(buf)
        actual = hasher.hexdigest()
        # hexdigest yields lowercase hex digits, hence normalise the expected value in order to not
        # reject a correct checksum that was merely given in upper or mixed case
        if actual != self.md5sum.lower():
            raise RuntimeError(
                f"incorrect md5sum for file {self.path.name}, should be {self.md5sum} but is {actual}"
            )
        logger.info(f"file {self.path.name} has correct hash!")