mml.core.data_loading.augmentations.mixup_cutmix

Adapted from

https://github.com/veritable-tech/pytorch-lightning-spells/blob/master/pytorch_lightning_spells/callbacks.py

which was adapted from

https://github.com/rwightman/pytorch-image-models/blob/8c9814e3f500e8b37aae86dd4db10aba2c295bd2/timm/data/mixup.py

which was partly adapted from

https://github.com/clovaai/CutMix-PyTorch

Papers:

  • MixUp: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)

  • CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features (https://arxiv.org/abs/1905.04899)

References:

  • rwightman/pytorch-image-models

  • veritable-tech/pytorch-lightning-spells

class CutMixCallback

Bases: MixingCallback

__init__(alpha: float = 0.4, label_smoothing: float = 0.0, minmax: Tuple[float, float] | None = None)
Callback that performs CutMix augmentation on training batches. It supports two strategies for sizing the bounding boxes:
  • by default, bounding box sizes are controlled via a beta distribution parameterized by alpha

  • or, if minmax is set, bbox side lengths are drawn as ratios between the given min and max of each image dimension, which yields a more uniform distribution of box sizes

Parameters:
  • alpha (float) – if minmax is None, this value parameterizes the beta distribution

  • label_smoothing (float) – if greater than 0 activates label smoothing

  • minmax (Optional[Tuple[float, float]]) – minimum and maximum bbox ratios (as a fraction of image size); typical values are 0.2 to 0.3 for min and 0.8 to 0.9 for max.

get_bbox_and_lam(img_shape: Tuple, lam: float) → Tuple[Tuple[ndarray, ndarray, ndarray, ndarray], ndarray]

Generate bounding boxes and apply lambda correction, i.e. recompute lambda from the area of the generated (possibly clipped) boxes.
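The lambda correction can be sketched as follows, assuming the convention of the timm code credited above: lambda is the fraction of the original image that is kept. Because a box may be clipped at the image border, its real area can differ from what the sampled lambda implies, so lambda is recomputed from the box itself. The helper name `corrected_lambda` and the (yl, yh, xl, xh) bound layout are assumptions for this sketch, not this module's API.

```python
import numpy as np


def corrected_lambda(yl, yh, xl, xh, img_h, img_w):
    """Hypothetical helper (sketch): recompute lambda from the actual box area.

    Works on scalars or on NumPy arrays holding one box per sample.
    lambda = fraction of the original image that is kept
           = 1 - (pasted box area) / (image area)
    """
    box_area = (np.asarray(yh) - yl) * (np.asarray(xh) - xl)
    return 1.0 - box_area / float(img_h * img_w)


# A 16x16 box in a 32x32 image replaces a quarter of the pixels.
print(corrected_lambda(0, 16, 0, 16, 32, 32))  # prints 0.75
```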

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Dict[str, Dict[str, Tensor]], batch_idx: int) → None

Triggered at the start of each training batch.
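The core CutMix step on a batch pairs every sample with its mirror in the batch (first with last, second with second to last) and copies a box from the partner into it. A minimal NumPy sketch, assuming images in (N, C, H, W) layout and boxes given as per-sample lower/upper row and column bounds; the actual callback works on torch tensors inside the batch dictionary, and `paste_partner_boxes` is a hypothetical name.

```python
import numpy as np


def paste_partner_boxes(images, yl, yh, xl, xh):
    """Hypothetical sketch: copy box i from sample N-1-i into sample i.

    images: array of shape (N, C, H, W).
    yl, yh, xl, xh: length-N integer arrays of row/column bounds per sample.
    Returns a new array; the input batch is left unchanged.
    """
    partners = images[::-1]  # sample i is paired with sample N-1-i
    mixed = images.copy()
    for i in range(images.shape[0]):
        rows, cols = slice(yl[i], yh[i]), slice(xl[i], xh[i])
        mixed[i, :, rows, cols] = partners[i, :, rows, cols]
    return mixed


# Two 1x4x4 "images": all zeros and all ones, each receiving a 2x2 box.
batch = np.stack([np.zeros((1, 4, 4)), np.ones((1, 4, 4))])
box = np.array([0, 0]), np.array([2, 2]), np.array([0, 0]), np.array([2, 2])
mixed = paste_partner_boxes(batch, *box)
```

Here each sample keeps 12 of its 16 pixels, so its (corrected) lambda would be 0.75; the targets are then mixed with the same per-sample fractions.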

rand_bbox(img_shape: Tuple, lam: float, margin: float = 0.0, count: int | None = None) → Tuple[ndarray, ndarray, ndarray, ndarray]

Standard CutMix bounding box. Generates a random square bbox based on the lambda value. This implementation supports enforcing a border margin as a fraction of the bbox dimensions.

Parameters:
  • img_shape – image shape as tuple

  • lam – cutmix lambda value

  • margin – fraction of the bbox dimensions to enforce as a border margin (reduces how much of the box falls outside the image)

  • count – number of bboxes to generate

Returns:

bounding box positions for the full batch
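A sketch of the standard box sampling, following the timm function this module credits: both side lengths scale with sqrt(1 - lam), so an unclipped box covers a (1 - lam) fraction of the image, and box centres are drawn uniformly while staying at least margin times the box size away from the border. The function name and the `rng` argument (added for reproducibility) are assumptions, not this module's signature.

```python
import numpy as np


def rand_bbox_sketch(img_shape, lam, margin=0.0, count=None, rng=None):
    """Hypothetical sketch: sample boxes whose unclipped area is (1 - lam) of the image.

    img_shape: tuple whose last two entries are (H, W).
    Returns (yl, yh, xl, xh): arrays of length `count`, or scalars if count is None.
    """
    rng = np.random.default_rng() if rng is None else rng
    img_h, img_w = img_shape[-2:]
    ratio = np.sqrt(1.0 - lam)
    cut_h, cut_w = int(img_h * ratio), int(img_w * ratio)
    # Keep box centres at least margin * box size away from the border.
    margin_y, margin_x = int(margin * cut_h), int(margin * cut_w)
    cy = rng.integers(margin_y, img_h - margin_y, size=count)
    cx = rng.integers(margin_x, img_w - margin_x, size=count)
    # Clip to the image; clipped boxes end up smaller than requested,
    # which is why lambda is corrected afterwards.
    yl = np.clip(cy - cut_h // 2, 0, img_h)
    yh = np.clip(cy + cut_h // 2, 0, img_h)
    xl = np.clip(cx - cut_w // 2, 0, img_w)
    xh = np.clip(cx + cut_w // 2, 0, img_w)
    return yl, yh, xl, xh


boxes = rand_bbox_sketch((3, 32, 32), lam=0.75, count=8, rng=np.random.default_rng(0))
```

With lam = 0.75 the requested box is 16x16 in a 32x32 image; clipping can only shrink it.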

rand_bbox_minmax(img_shape: Tuple[float, float, float], count: int | None = None) → Tuple[ndarray, ndarray, ndarray, ndarray]

Alternative min-max CutMix bounding box. Inspired by the Darknet CutMix implementation, it generates a random rectangular bbox based on min/max ratios applied to each dimension of the input image.

Parameters:
  • img_shape – image shape as tuple

  • count – number of bboxes to generate

Returns:

bounding box positions for the full batch
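A sketch of the min-max variant, again following the timm code credited above: height and width are drawn independently and uniformly from [min * size, max * size), and the top-left corner is placed so the whole box lies inside the image; lambda then follows from the box area. Passing minmax as an argument (rather than reading it from the callback) and the function name are assumptions made to keep the sketch self-contained.

```python
import numpy as np


def rand_bbox_minmax_sketch(img_shape, minmax, count=None, rng=None):
    """Hypothetical sketch: rectangular boxes with side ratios drawn from [min, max).

    img_shape: tuple whose last two entries are (H, W).
    minmax: (min_ratio, max_ratio), e.g. (0.2, 0.8).
    Returns (yl, yh, xl, xh); boxes always lie fully inside the image.
    """
    rng = np.random.default_rng() if rng is None else rng
    img_h, img_w = img_shape[-2:]
    # Height and width ratios are sampled independently, so boxes are rectangular.
    cut_h = rng.integers(int(img_h * minmax[0]), int(img_h * minmax[1]), size=count)
    cut_w = rng.integers(int(img_w * minmax[0]), int(img_w * minmax[1]), size=count)
    # Place the top-left corner so the full box fits; no clipping is needed.
    yl = rng.integers(0, img_h - cut_h, size=count)
    xl = rng.integers(0, img_w - cut_w, size=count)
    return yl, yl + cut_h, xl, xl + cut_w


boxes = rand_bbox_minmax_sketch((3, 100, 100), (0.2, 0.8), count=16, rng=np.random.default_rng(0))
```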

class MixUpCallback

Bases: MixingCallback

__init__(alpha: float = 0.4, label_smoothing: float = 0.0)

Callback that performs MixUp augmentation on training batches.

Parameters:
  • alpha (float) – parameter of the beta distribution from which the mixing factor (between 0 and 1) is sampled

  • label_smoothing (float) – if greater than 0 activates label smoothing

on_train_batch_start(trainer: Trainer, pl_module: LightningModule, batch: Dict[str, Dict[str, Tensor]], batch_idx: int) → None

Triggered at the start of each training batch.
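A MixUp batch step can be sketched as follows, assuming per-sample mixing factors drawn from Beta(alpha, alpha) as in the MixUp paper and the same mirror pairing (first with last, and so on) used when mixing targets. NumPy arrays stand in for the torch tensors in the batch dictionary; `mixup_batch` is a hypothetical name.

```python
import numpy as np


def mixup_batch(images, alpha=0.4, rng=None):
    """Hypothetical sketch: blend sample i with sample N-1-i, lam_i ~ Beta(alpha, alpha).

    Returns the mixed batch and the lambdas (weight of each original sample).
    """
    rng = np.random.default_rng() if rng is None else rng
    n = images.shape[0]
    lam = rng.beta(alpha, alpha, size=n)
    # Reshape to (N, 1, 1, ...) so each weight broadcasts over its own sample.
    w = lam.reshape((n,) + (1,) * (images.ndim - 1))
    mixed = w * images + (1.0 - w) * images[::-1]
    return mixed, lam


imgs = np.stack([np.zeros((1, 2, 2)), np.ones((1, 2, 2))])
mixed, lam = mixup_batch(imgs, alpha=0.4, rng=np.random.default_rng(0))
```

Small alpha values push lambda towards 0 or 1 (mild mixing for most samples); alpha = 1 gives a uniform lambda.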

class MixingCallback

Bases: Callback

__init__(alpha: float = 0.4, label_smoothing: float = 0.0)

Base class for MML data mixing callbacks.

Parameters:
  • alpha (float) – parameter of the beta distribution from which the mixing factor (between 0 and 1) is sampled

  • label_smoothing (float) – if greater than 0 activates label smoothing

mixup_targets(targets: Tensor, lambdas: ndarray, task: str) → Tensor

Mixes the targets of a task.

Parameters:
  • targets (torch.Tensor) – batched task targets; the first target is mixed with the last, the second with the second to last, and so on

  • lambdas (np.ndarray) – the actual mixing fraction of each sample

  • task (str) – name of the task

Returns:

(optionally) smoothed and then mixed targets
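Target mixing mirrors the input mixing: the encoded target of sample i is blended with that of sample N-1-i using the same lambda. A sketch assuming the targets have already been one-hot encoded (and optionally smoothed) into an (N, num_classes) array and that lambda weights the original sample, as in the timm code credited above; `mix_targets` is a hypothetical name, not the method itself.

```python
import numpy as np


def mix_targets(one_hot_targets, lambdas):
    """Hypothetical sketch: row i becomes lam_i * y_i + (1 - lam_i) * y_{N-1-i}."""
    lam = np.asarray(lambdas, dtype=float)[:, None]  # shape (N, 1) for broadcasting
    return lam * one_hot_targets + (1.0 - lam) * one_hot_targets[::-1]


targets = np.eye(3)  # three samples of classes 0, 1 and 2
mixed = mix_targets(targets, [0.7, 0.5, 0.9])
# Row 0 becomes 0.7 * class 0 + 0.3 * class 2; the middle row is mixed with itself.
```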

setup(trainer: Trainer, pl_module: LightningModule, stage: str | None = None) → None

During setup, the task structs are inspected for compatibility and the number of classes is extracted. Because the training data is mixed, the torchmetrics training metrics are deactivated.

smooth_one_hot(x: Tensor, task: str) → Tensor

One-hot encodes a task's targets, with smoothing controlled via label_smoothing.

Parameters:
  • x (torch.Tensor) – batched task targets

  • task (str) – name of the task

Returns:

one-hot encoded task targets, smoothed if label_smoothing > 0
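A sketch of one-hot encoding with label smoothing. It uses the convention from the timm code credited above: every class receives label_smoothing / num_classes and the true class receives the remaining probability mass, so each row still sums to 1. Whether this module uses exactly the same convention, the integer class-index input, and the function name are assumptions.

```python
import numpy as np


def smooth_one_hot_sketch(labels, num_classes, label_smoothing=0.0):
    """Hypothetical sketch: one-hot encode integer labels with label smoothing."""
    labels = np.asarray(labels)
    off_value = label_smoothing / num_classes
    on_value = 1.0 - label_smoothing + off_value
    out = np.full((labels.shape[0], num_classes), off_value)
    out[np.arange(labels.shape[0]), labels] = on_value
    return out


y = smooth_one_hot_sketch([0, 2], num_classes=4, label_smoothing=0.1)
# Each row: 0.925 on the true class and 0.025 on the other three classes.
```

With label_smoothing = 0 this reduces to plain one-hot encoding.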