mml.core.data_loading.augmentations.torchvision
- class TorchvisionAugmentationModule
Bases: AugmentationModule
Torchvision V2 augmentation module.
- __init__(device: str, cfg: ListConfig, is_first: bool, is_last: bool, means: RGBInfo | None, stds: RGBInfo | None, num_classes: int | None = None)
- static from_cfg(aug_config: ListConfig, on_cpu: bool, num_classes: int | None = None) → List[Transform]
Takes a config and returns a list of corresponding transforms.
- Parameters:
aug_config (ListConfig) – the augmentations to build; see configs/augmentations/torchvision.yaml for an example. Either the “cpu” or the “gpu” attribute of that config may be passed to this function.
on_cpu (bool) – determines whether the transforms are applied on the CPU (one sample at a time inside a dataloader worker) or on the GPU (on whole batches).
num_classes (Optional[int]) – number of classes; required only when the config contains MixUp or CutMix transforms.
- Returns:
a list of torchvision v2 transforms
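The config-to-transforms step can be sketched as a name lookup. This is a minimal illustration, not the module's implementation: the function name `build_transforms`, the `namespace` argument, and the entry format (a list of single-key mappings such as `{"RandomHorizontalFlip": {"p": 0.5}}`) are all assumptions, since the actual layout of configs/augmentations/torchvision.yaml is not shown here. The only documented behaviour it mirrors is that MixUp and CutMix need `num_classes`.

```python
from typing import Any, Callable, List, Mapping, Optional, Sequence

# Transforms that need num_classes injected (mirrors the documented num_classes parameter).
_NEEDS_NUM_CLASSES = {"MixUp", "CutMix"}


def build_transforms(
    aug_list: Sequence[Mapping[str, Any]],
    namespace: Any,
    num_classes: Optional[int] = None,
) -> List[Callable]:
    """Hypothetical sketch: instantiate one transform per config entry.

    Assumed entry format: {"TransformName": {kwarg: value, ...}}.
    `namespace` is any object exposing transform classes as attributes,
    e.g. the torchvision.transforms.v2 module.
    """
    transforms: List[Callable] = []
    for entry in aug_list:
        if len(entry) != 1:
            raise ValueError(f"expected one transform per entry, got {list(entry)}")
        name, kwargs = next(iter(entry.items()))
        kwargs = dict(kwargs or {})  # copy so the config itself is not mutated
        if name in _NEEDS_NUM_CLASSES:
            if num_classes is None:
                raise ValueError(f"{name} requires num_classes")
            kwargs["num_classes"] = num_classes
        # Look the class up by name, e.g. namespace.RandomHorizontalFlip
        transforms.append(getattr(namespace, name)(**kwargs))
    return transforms
```

With torchvision installed, passing `torchvision.transforms.v2` as `namespace` resolves names such as `RandomHorizontalFlip` or `MixUp` against the real module, and the returned list can be wrapped in `v2.Compose`. Injecting the module keeps the sketch testable without torchvision.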
- apply_tv_tensor_types(inpt: Dict[str, Dict[str, Any]]) → Dict[str, Dict[str, Any]]
Converts the plain tensors in a modality dict structure into the corresponding TVTensors.
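Wrapping matters because torchvision v2 transforms dispatch on the TVTensor subclass of each input, so a mask is not handled like an image. Below is a minimal sketch of the conversion, assuming the outer key is a modality name and the inner dict holds that modality's tensors. The function name `wrap_as_tv_tensors`, the `wrappers` argument, and that layout are hypothetical.

```python
from typing import Any, Callable, Dict, Mapping


def wrap_as_tv_tensors(
    inpt: Dict[str, Dict[str, Any]],
    wrappers: Mapping[str, Callable[[Any], Any]],
) -> Dict[str, Dict[str, Any]]:
    """Hypothetical sketch: wrap each value with the wrapper registered for its modality.

    Returns a new dict; the input structure is not modified.
    """
    out: Dict[str, Dict[str, Any]] = {}
    for modality, fields in inpt.items():
        wrap = wrappers.get(modality)
        if wrap is None:
            # No wrapper registered: copy the inner dict, leave values as plain tensors
            out[modality] = dict(fields)
        else:
            out[modality] = {key: wrap(value) for key, value in fields.items()}
    return out
```

In real use the wrappers would be TVTensor classes, e.g. `{"rgb": tv_tensors.Image, "semseg": tv_tensors.Mask}` after `from torchvision import tv_tensors` (torchvision 0.16 or later). The modality names here are placeholders.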
- mixup_cutmix_labels_getter(batch: Dict[str, Dict[str, Tensor]]) → Tensor
Helper function to extract labels from a batch for torchvision v2 MixUp and CutMix transforms. See https://pytorch.org/vision/main/auto_examples/transforms/plot_cutmix_mixup.html#non-standard-input-format.
- Parameters:
batch (Dict[str, Dict[str, torch.Tensor]]) – the full batch as returned by the dataloader; exactly one task is expected.
- Returns:
classification labels of the single task
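A labels getter of this kind can be sketched as follows, assuming the labels sit under a hypothetical `"label"` key as `{task_name: labels}`. The name `single_task_labels_getter` and that layout are illustrative, not the module's actual batch format.

```python
from typing import Any, Dict


def single_task_labels_getter(batch: Dict[str, Dict[str, Any]], label_key: str = "label") -> Any:
    """Hypothetical sketch: return the labels of the only task under batch[label_key]."""
    tasks = batch[label_key]
    if len(tasks) != 1:
        raise ValueError(f"expected exactly one task, found {sorted(tasks)}")
    # Return the stored object itself (no copy): torchvision locates the labels
    # inside the flattened input by identity when it replaces them with mixed labels.
    return next(iter(tasks.values()))
```

It would be passed as `v2.CutMix(num_classes=10, labels_getter=single_task_labels_getter)` (likewise for `v2.MixUp`), and the transform is then called on a whole batch. Torchvision raises if the returned labels are not a `torch.Tensor`.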