Source code for mml.core.data_loading.augmentations.kornia

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

from typing import Any, List, Optional

import kornia as K
import torch
from hydra.utils import instantiate
from omegaconf import ListConfig, OmegaConf

from mml.core.data_loading.augmentations.augmentation_module import AugmentationModule, DataFormat
from mml.core.data_loading.task_attributes import Modality, RGBInfo

KORNIA_VALID_MODALITIES = ["image", "mask", "bbox", "keypoints"]


[docs] class KorniaAugmentationModule(AugmentationModule): """ Kornia augmentation module. Basic supported dict entries are ['image', 'mask', 'bbox', 'keypoints'] """
[docs] def __init__( self, device: str, cfg: ListConfig, is_first: bool, is_last: bool, means: Optional[RGBInfo], stds: Optional[RGBInfo], ): self.data_keys = None super().__init__(device=device, cfg=cfg, is_first=is_first, is_last=is_last, means=means, stds=stds) self.run_inverse = False
def _build_pipeline(self): transforms = self.from_cfg(self.cfg) norm_trans = [] if self.means is None and self.stds is None: pass elif sum([x is None for x in [self.means, self.stds]]) == 1: raise RuntimeError( "Was presented either only STD or only MEAN normalization values. Require either none or both!" ) else: norm_trans.append(K.augmentation.Normalize(mean=self.means.get_rgb(), std=self.stds.get_rgb())) aug = K.augmentation.container.AugmentationSequential(*transforms, *norm_trans) # deactivate gradients of augmentations forward aug.forward = torch.no_grad()(aug.forward) self.pipeline = aug def _forward_impl(self, inpt: Any) -> Any: if self.data_format == DataFormat.BATCHED_SAMPLE_DICTS: inpt = {"dummy": inpt} for task in inpt: sub_batch = inpt[task] # use order from dataset aug_modalities = [mod for mod in sub_batch if mod in KORNIA_VALID_MODALITIES] if Modality.MASK.value in aug_modalities: # kornia requires mask to be float and same dimensions as image sub_batch[Modality.MASK.value] = sub_batch[Modality.MASK.value].unsqueeze(1).float() # disassemble batch batch_list = [sub_batch[mod] for mod in aug_modalities] # augment if self.run_inverse: augmented = self.pipeline.inverse(*batch_list, data_keys=aug_modalities) else: augmented = self.pipeline(*batch_list, data_keys=aug_modalities) # re-assemble if len(aug_modalities) == 1: # kornia treats this case differently and returns tensor instead of list sub_batch.update({aug_modalities[0]: augmented}) else: sub_batch.update({mod: augmented[ix] for ix, mod in enumerate(aug_modalities)}) if Modality.MASK.value in aug_modalities: # undo kornia requirement of mask to be float and same dimensions as image sub_batch[Modality.MASK.value] = sub_batch[Modality.MASK.value].squeeze(1).long() if self.data_format == DataFormat.BATCHED_SAMPLE_DICTS: inpt = inpt["dummy"] return inpt
[docs] def inverse(self, inpt: Any) -> Any: """ API to use the feature of kornia augmentations to invertible. Reverts the last augmentation. Useful for TTA. :param inpt: a transformed batch (potentially after inference by a model) :return: the model results with undone geometry """ self.run_inverse = True outpt = self.__call__(inpt) self.run_inverse = False return outpt
def _sanity_check(self, inpt: Any) -> None: assert self.device == "gpu", "so far only gpu is supported for KorniaAugmentationModule" assert self.data_format in [DataFormat.MULTI_TASK_SAMPLE_DICTS, DataFormat.BATCHED_SAMPLE_DICTS] modalities = [] if self.data_format == DataFormat.MULTI_TASK_SAMPLE_DICTS: for task in inpt: modalities.extend(list(inpt[task].keys())) else: modalities = list(inpt.keys()) modalities = [ mod for mod in modalities if mod not in [Modality.CLASS.value, Modality.CLASSES.value, Modality.SOFT_CLASSES.value, Modality.SAMPLE_ID.value] ] if any([mod not in KORNIA_VALID_MODALITIES for mod in modalities]): raise ValueError(f"Some modality might not be supported by kornia backend: {modalities}") self.data_keys = modalities
[docs] @staticmethod def from_cfg(aug_config: ListConfig) -> List[K.augmentation.AugmentationBase2D]: """ Takes a config and returns a list of corresponding transforms. :param DictConfig aug_config: see configs/augmentations/kornia.yaml for an example :return: a list of kornia transforms """ aug_config = OmegaConf.to_container(aug_config, resolve=True) transforms = [] for transform_args in aug_config: transform_name = transform_args.pop("name") if transform_name == "RandAugment": transform = K.auto.RandAugment(**transform_args) elif transform_name == "AutoAugment": transform = K.auto.AutoAugment(**transform_args) elif transform_name == "TrivialAugment": transform = K.auto.TrivialAugment(**transform_args) else: _dict = {"_target_": "kornia.augmentation._2d." + transform_name} _dict.update(**transform_args) transform = instantiate(_dict) transforms.append(transform) return transforms