# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
import hashlib
from typing import Any, Dict, List, Optional
import albumentations as A
import numpy as np
from albumentations.pytorch import ToTensorV2
from omegaconf import ListConfig, OmegaConf
from mml.core.data_loading.augmentations.augmentation_module import (
IMAGENET_AA_PATH,
AugmentationModule,
DataFormat,
Transform,
)
from mml.core.data_loading.file_manager import MMLFileManager
from mml.core.data_loading.task_attributes import RGBInfo
# Sample-dict keys that are handed to the albumentations pipeline. Entries stored under any other
# key are copied to the output without being passed to the pipeline (see ``_forward_impl``).
ALBUMENTATIONS_VALID_MODALITIES = ["image", "mask"]
[docs]
class AlbumentationsAugmentationModule(AugmentationModule):
"""
Albumentations augmentation module.
Basic supported dict entries are ['image', 'mask', 'bboxes', 'keypoints'].
An AutoAlbument generated pipeline is available via the Identifier "ImageNetAA" (no parameters). "RandAugment" with
parameters is also provided on top. See :meth:`get_rand_augment` for details.
Note that once bboxes and keypoints are supported, the composition will have to include the respective parameters, see
https://albumentations.ai/docs/api_reference/core/composition/. Furthermore, "additional_targets" might need to be
defined. To check whether a certain augmentation supports a specific target type see
https://albumentations.ai/docs/getting_started/transforms_and_targets/.
"""
[docs]
def __init__(
    self,
    device: str,
    cfg: ListConfig,
    is_first: bool,
    is_last: bool,
    means: Optional[RGBInfo],
    stds: Optional[RGBInfo],
    floatify: bool = False,
    tensorize: bool = True,
):
    """
    Store the output conversion flags and delegate the remaining setup to the parent class.

    :param device: forwarded unchanged to the parent constructor
    :param cfg: forwarded unchanged to the parent constructor, later consumed by :meth:`from_cfg`
    :param is_first: forwarded unchanged to the parent constructor
    :param is_last: forwarded unchanged to the parent constructor
    :param means: normalization means, must be given together with ``stds`` or not at all
    :param stds: normalization stds, must be given together with ``means`` or not at all
    :param floatify: if set, the pipeline gets a float conversion appended unless one is already part of it
    :param tensorize: if set, the pipeline ends with ``ToTensorV2``
    :raises ValueError: if ``floatify`` is requested without ``tensorize``
    """
    # the flags are stored before calling the parent constructor
    # NOTE(review): presumably the parent constructor triggers _build_pipeline, which reads them - confirm
    self.tensorize, self.floatify = tensorize, floatify
    float_without_tensor = floatify and not tensorize
    if float_without_tensor:
        raise ValueError("floatify is to ensure float tensors are moved to device")
    super().__init__(device=device, cfg=cfg, is_first=is_first, is_last=is_last, means=means, stds=stds)
def _build_pipeline(self):
    """
    Compose the configured transforms with the trailing conversion steps and store them as ``self.pipeline``.

    The trailing steps are, in this order:

    * ``A.Normalize`` if both means and stds are given, otherwise ``A.ToFloat(max_value=255)`` if this is the
      last module or ``floatify`` is set (never both, never twice)
    * ``ToTensorV2`` if ``tensorize`` is set

    :raises RuntimeError: if exactly one of means / stds is provided
    """
    configured = self.from_cfg(self.cfg)
    mean_missing = self.means is None
    std_missing = self.stds is None
    if mean_missing != std_missing:
        raise RuntimeError(
            "Was presented either only STD or only MEAN normalization values. Require either none or both!"
        )
    closing = []
    if not mean_missing:
        # default case: requested normalization, which already produces float values
        closing.append(A.Normalize(mean=self.means.get_rgb(), std=self.stds.get_rgb()))
    elif self.is_last or self.floatify:
        # no normalization requested, but float values are required: either because this is the last
        # transform or because float tensors are needed to support lightning precision
        closing.append(A.ToFloat(max_value=255))
    if self.tensorize:
        closing.append(ToTensorV2())
    self.pipeline = A.Compose([*configured, *closing])
def _forward_impl(self, inpt: Dict[str, Any]) -> Dict[str, Any]:
    """
    Apply the pipeline to a single sample.

    :param inpt: sample dict; only the entries listed in ``ALBUMENTATIONS_VALID_MODALITIES`` are given to the
        pipeline
    :return: the pipeline output, extended by all remaining entries of ``inpt``
    """
    # albumentations only handles single sample inputs, receive a dict and return a dict
    handled = {}
    for modality in ALBUMENTATIONS_VALID_MODALITIES:
        if modality in inpt:
            handled[modality] = inpt[modality]
    result = self.pipeline(**handled)
    # everything the pipeline did not receive is carried over as is
    passthrough = {}
    for key, value in inpt.items():
        if key not in ALBUMENTATIONS_VALID_MODALITIES:
            passthrough[key] = value
    result.update(passthrough)
    return result
def _sanity_check(self, inpt: Any) -> None:
    """
    Assert that this module is set up for the "cpu" device and the single sample dict data format.

    :param inpt: not inspected by this check
    :raises AssertionError: if the device is not "cpu" or the data format is not ``SINGLE_SAMPLE_DICT``
    """
    # albumentations sanity checks
    assert self.device == "cpu"
    assert self.data_format == DataFormat.SINGLE_SAMPLE_DICT
def __hash__(self):
    """
    MD5 hex digest of the serialized pipeline.

    The pipeline is dumped with ``A.save`` to a path obtained from the file manager (key "temp") and the
    written file is hashed chunk by chunk.

    NOTE(review): the return value is the ``str`` hex digest, not an ``int``. The built-in ``hash()`` requires
    an integer, so this is presumably meant to be called directly as ``obj.__hash__()`` - confirm with callers
    before changing the return type.
    """
    dump_path = str(MMLFileManager.instance().construct_saving_path(None, key="temp", file_name="hash_dump.json"))
    A.save(self.pipeline, dump_path)
    digest = hashlib.md5()
    with open(dump_path, "rb") as dump:
        # read in 64 KiB chunks until the empty bytes sentinel signals end of file
        for chunk in iter(lambda: dump.read(65536), b""):
            digest.update(chunk)
    return digest.hexdigest()
[docs]
@staticmethod
def from_cfg(aug_config: ListConfig) -> List[Transform]:
    """
    Turn an augmentation config into the list of corresponding transforms.

    Every entry needs a "name", all remaining keys of the entry are used as keyword arguments. The names
    "RandAugment" (see :meth:`get_rand_augment`) and "ImageNetAA" (loaded from ``IMAGENET_AA_PATH``, its
    remaining keys are ignored) are handled separately, any other name is resolved as an albumentations
    transform class.

    :param ListConfig aug_config: see configs/augmentations/default.yaml for an example, the "cpu" attribute is
        what has to be passed to this function.
    :return: a list of albumentation transforms
    """
    class_prefix = "albumentations.augmentations.transforms."
    built: List[Transform] = []
    # work on a resolved plain python copy, so popping the name does not touch the config itself
    for entry in OmegaConf.to_container(aug_config, resolve=True):
        name = entry.pop("name")
        if name == "ImageNetAA":
            step = A.load(IMAGENET_AA_PATH)  # type: ignore
        elif name == "RandAugment":
            step = AlbumentationsAugmentationModule.get_rand_augment(**entry)
        else:
            spec = {"__class_fullname__": class_prefix + name}
            spec.update(entry)
            step = A.from_dict({"transform": spec})  # type: ignore
        built.append(step)
    return built
[docs]
@staticmethod
def get_rand_augment(
    number: int, magnitude: int, p: float = 1.0, mode: str = "all", cut_out: bool = False
) -> A.BaseCompose:
    """
    Gets RandAugment transform. For details see https://arxiv.org/abs/1909.13719.

    :param number: number of transforms to be applied (excluding cut_out if active)
    :param magnitude: int between 1 and 9 determining strength of transformation
    :param p: probability handed to every individual operation (and to cutout if active), the selection of
        ``number`` operations itself is always performed
    :param mode: either 'geo' for geometrical transforms, 'color' for color transforms or 'all' for both of them
    :param cut_out: indicating if cutout should be applied
    :return: an albumentations transform
    :raises AssertionError: if magnitude is outside of 1 - 9 or mode is unknown
    """
    MAX_MAGNITUDE = 10
    assert 0 < magnitude < MAX_MAGNITUDE, f"magnitude range is 1 - 9, was given {magnitude}"
    assert mode in ["geo", "color", "all"], f"incorrect RandAugment mode {mode}, provide one of [color, geo, all]"

    def level(low: float, high: float) -> float:
        """Entry ``magnitude`` of MAX_MAGNITUDE evenly spaced values from ``low`` to ``high``."""
        return float(np.linspace(low, high, MAX_MAGNITUDE)[magnitude])

    translate, rotate, scale, shear = level(0, 1), level(0, 45), level(0, 0.5), level(0, 35)
    geo_ops = [
        A.Affine(translate_percent=(-translate, translate), p=p, cval_mask=255),
        A.Affine(rotate=(-rotate, rotate), p=p, cval_mask=255),
        A.Affine(scale=(1 - scale, 1 + scale), p=p, cval_mask=255, keep_ratio=True),
        A.Affine(shear=(-shear, shear), p=p, cval_mask=255),
    ]
    color_ops = [
        A.InvertImg(p=p),
        A.Equalize(p=p),
        A.Solarize(threshold=level(0, 256), p=p),
        A.Posterize(num_bits=int(level(0, 8)), p=p),
        A.RandomBrightnessContrast(brightness_limit=level(0, 0.8), contrast_limit=0.0, p=p),
        A.RandomBrightnessContrast(contrast_limit=level(0, 0.8), brightness_limit=0.0, p=p),
        A.Sharpen(alpha=(0.1, level(0.1, 0.9)), lightness=(0.4, level(0.4, 1.0)), p=p),
    ]
    # two named lists instead of index slicing: the former ops[5:] silently dropped InvertImg, the first
    # color operation, from the "color" mode
    ops = {"geo": geo_ops, "color": color_ops, "all": geo_ops + color_ops}[mode]
    # draws `number` operations with replacement, each one then fires with its own probability p
    transforms = A.SomeOf(transforms=ops, n=number, replace=True, p=1)
    if cut_out:
        hole_extent = (0.05, level(0.05, 0.2))
        transforms = A.Sequential(
            transforms=[
                transforms,
                A.CoarseDropout(
                    num_holes_range=(4, 8),
                    hole_height_range=hole_extent,
                    hole_width_range=hole_extent,
                    mask_fill_value=255,
                    p=p,
                ),
            ]
        )
    return transforms