mml.core.data_loading.utils

one_hot_mask(mask: Tensor, num_classes: int) → Tensor

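Expects a batched segmentation mask of shape (B x H x W) and returns its batched one-hot encoding of shape (B x C x H x W), where C = num_classes. The mask may also contain the pseudo entry 255 (e.g. an ignore label), which the function handles.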

Parameters:
  • mask (torch.Tensor) – the mask to be transformed, of shape (B x H x W); values must be below num_classes (apart from the pseudo entry 255)

  • num_classes (int) – the number of classes

Raises:

ValueError – if mask values (other than the pseudo entry 255) are outside [0, …, num_classes - 1]

Returns:

a batched one-hot encoding of shape (B x C x H x W)

Return type:

torch.Tensor
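To make the shape transformation concrete, here is a minimal sketch of the documented behaviour. It is not the mml implementation. It assumes that pixels equal to 255 are encoded as all-zero vectors and that the result keeps the int64 dtype produced by torch.nn.functional.one_hot. The names one_hot_mask_sketch and IGNORE_VALUE are hypothetical.

```python
import torch
import torch.nn.functional as F

IGNORE_VALUE = 255  # hypothetical name for the pseudo entry 255


def one_hot_mask_sketch(mask: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical sketch: (B x H x W) mask -> (B x C x H x W) one-hot.

    Assumption: pixels equal to 255 become all-zero vectors.
    """
    ignore = mask == IGNORE_VALUE
    valid = mask[~ignore]
    # Reject labels outside [0, num_classes - 1], excluding the pseudo entry.
    if valid.numel() > 0 and (valid.min() < 0 or valid.max() >= num_classes):
        raise ValueError("mask values must lie in [0, num_classes - 1] or equal 255")
    # Temporarily map ignored pixels to class 0 so F.one_hot accepts them.
    safe = mask.masked_fill(ignore, 0).long()
    encoded = F.one_hot(safe, num_classes)  # B x H x W x C, dtype int64
    encoded[ignore] = 0  # clear the vectors of ignored pixels
    return encoded.permute(0, 3, 1, 2)  # B x C x H x W


mask = torch.tensor([[[0, 1], [255, 2]]])  # B=1, H=2, W=2
out = one_hot_mask_sketch(mask, num_classes=3)
print(out.shape)  # torch.Size([1, 3, 2, 2])
print(out[0, :, 1, 0].tolist())  # [0, 0, 0] for the ignored pixel
```

The 255 pixels are first mapped to a valid class and then zeroed, because F.one_hot rejects indices greater than or equal to num_classes.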