mml.core.data_loading.augmentations.torchvision

class TorchvisionAugmentationModule

Bases: AugmentationModule

Torchvision V2 augmentation module.

__init__(device: str, cfg: ListConfig, is_first: bool, is_last: bool, means: RGBInfo | None, stds: RGBInfo | None, num_classes: int | None = None)
static from_cfg(aug_config: ListConfig, on_cpu: bool, num_classes: int | None = None) → List[Transform]

Takes a config and returns a list of corresponding transforms.

Parameters:
  • aug_config (ListConfig) – see configs/augmentations/torchvision.yaml for an example; either the “cpu” or the “gpu” attribute of that config may be passed to this function.

  • on_cpu (bool) – determines whether the transforms are performed on the CPU (single sample in a worker) or on the GPU (batched)

  • num_classes (Optional[int]) – number of classes; optional in general, but required if the config contains MixUp or CutMix transforms

Returns:

a list of torchvision v2 transforms

apply_tv_tensor_types(inpt: Dict[str, Dict[str, Any]]) → Dict[str, Dict[str, Any]]

Converts plain tensors organised in the modality dict structure into the corresponding TVTensors.

See https://pytorch.org/vision/stable/tv_tensors.html

Parameters:

inpt (Dict[str, Dict[str, Any]]) – input, must be in DataFormat.MULTI_TASK_SAMPLE_DICTS

Returns:

the same structure, with each tensor wrapped in the corresponding tv_tensors class

mixup_cutmix_labels_getter(batch: Dict[str, Dict[str, Tensor]]) → Tensor

Helper function to extract labels from a batch for torchvision v2 MixUp and CutMix transforms. See https://pytorch.org/vision/main/auto_examples/transforms/plot_cutmix_mixup.html#non-standard-input-format.

Parameters:

batch (Dict[str, Dict[str, torch.Tensor]]) – full batch as returned by dataloader, expects a single task

Returns:

classification labels of the single task