mml.core.data_loading.lightning_datamodule

class MultiTaskDataModule

Bases: LightningDataModule

This class wraps one or more TaskDataset instances for Lightning. Given the corresponding TaskStruct instances, it takes care of setting up all the correct data splits. In particular, it interprets the following elements of the config:

Importantly, it provides the necessary Lightning interface (e.g., setup(), train_dataloader()).
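As a rough illustration of the order in which a Lightning trainer drives a datamodule during a fit run, the toy sketch below uses a hypothetical `ToyDataModule` and a simplified `run_fit` driver in plain Python (no Lightning dependency). The real `Trainer` interleaves many more hooks and also requests validation loaders; this only shows that `prepare_data()` precedes `setup()`, which precedes the dataloaders, with `teardown()` last.

```python
from typing import List, Optional


class ToyDataModule:
    """Hypothetical stand-in that only records which hooks were called."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def prepare_data(self) -> None:
        # Typically called on a single process per node (e.g. download/cache data).
        self.calls.append("prepare_data")

    def setup(self, stage: str) -> None:
        # Called on every process; builds the datasets needed for `stage`.
        self.calls.append(f"setup:{stage}")

    def train_dataloader(self) -> List[int]:
        self.calls.append("train_dataloader")
        return [0, 1, 2]  # three toy "batches"

    def teardown(self, stage: Optional[str] = None) -> None:
        self.calls.append(f"teardown:{stage}")


def run_fit(dm: ToyDataModule) -> int:
    """Simplified driver: prepare, set up, iterate once, tear down."""
    dm.prepare_data()
    dm.setup("fit")
    n_batches = sum(1 for _ in dm.train_dataloader())
    dm.teardown("fit")
    return n_batches
```

After `dm = ToyDataModule(); run_fit(dm)`, `dm.calls` lists the hooks in the order they were invoked.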

__init__(task_structs: List[TaskStruct], cfg: DictConfig, fold: int = 0)
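The `fold` argument suggests support for k-fold cross-validation, but how folds are assigned is not documented here. The sketch below shows one common scheme purely as an assumption: a hypothetical `split_for_fold` helper assigns sample `i` to fold `i % n_folds` and holds out the selected fold for validation.

```python
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_for_fold(samples: Sequence[T], n_folds: int, fold: int = 0) -> Tuple[List[T], List[T]]:
    """Hypothetical k-fold split: sample i belongs to fold i % n_folds.

    Returns (train, val), where val holds the samples of the selected fold.
    """
    if not 0 <= fold < n_folds:
        raise ValueError(f"fold must be in [0, {n_folds}), got {fold}")
    train = [s for i, s in enumerate(samples) if i % n_folds != fold]
    val = [s for i, s in enumerate(samples) if i % n_folds == fold]
    return train, val
```

Real pipelines usually shuffle the samples with a fixed seed before assigning folds, so that every fold remains stable across runs while not depending on file order.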
get_cpu_transforms(struct: TaskStruct, phase: LearningPhase = LearningPhase.TRAIN) → AugmentationModule | AugmentationModuleContainer

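Returns the CPU transforms required for the given task and learning phase.

Parameters:

struct (TaskStruct): the task whose transforms are requested

phase (LearningPhase): the learning phase for which the transforms are built, defaults to LearningPhase.TRAIN

Returns:

the transform pipeline, either a single AugmentationModule or an AugmentationModuleContainer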
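Because `get_cpu_transforms()` takes a `phase`, the returned pipeline presumably differs between training and evaluation. The sketch below illustrates that idea with plain callables; the `Phase` enum, `compose` helper and the toy transforms are hypothetical placeholders, not MML's `LearningPhase` or `AugmentationModule`. The "augmentation" is deterministic here only so the behaviour is easy to check.

```python
from enum import Enum
from typing import Callable, List, Sequence

Transform = Callable[[List[float]], List[float]]


class Phase(Enum):
    # Hypothetical stand-in for LearningPhase.
    TRAIN = "train"
    VAL = "val"
    TEST = "test"


def compose(transforms: Sequence[Transform]) -> Transform:
    """Chain transforms left to right."""
    def apply(x: List[float]) -> List[float]:
        for t in transforms:
            x = t(x)
        return x
    return apply


def scale(x: List[float]) -> List[float]:
    # Toy preprocessing applied in every phase.
    return [v / 255.0 for v in x]


def flip(x: List[float]) -> List[float]:
    # Toy "augmentation": reverse the signal (real augmentations are random).
    return x[::-1]


def get_cpu_transforms_sketch(phase: Phase = Phase.TRAIN) -> Transform:
    """Deterministic preprocessing always; augmentation only while training."""
    steps: List[Transform] = [scale]
    if phase is Phase.TRAIN:
        steps.append(flip)
    return compose(steps)
```

Keeping preprocessing shared and appending augmentation only for training ensures evaluation sees the same input distribution as training minus the random perturbations.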

static get_dataset_balancing_weights(ds: TaskDataset) → Tensor
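The weighting scheme behind `get_dataset_balancing_weights()` is not documented here. A common approach, shown purely as an assumption, is to weight each sample by the inverse frequency of its class; such per-sample weights can then drive a weighted sampler such as PyTorch's `torch.utils.data.WeightedRandomSampler`. The sketch uses plain lists instead of a `Tensor` and a hypothetical `balancing_weights` name.

```python
from collections import Counter
from typing import Hashable, List, Sequence


def balancing_weights(labels: Sequence[Hashable]) -> List[float]:
    """Inverse-class-frequency weight per sample (hypothetical scheme).

    Every class receives the same total weight, so sampling proportionally
    to these weights draws each class with equal probability.
    """
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]
```

For labels `["a", "a", "b"]` this yields `[0.5, 0.5, 1.0]`: both classes sum to a total weight of 1.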
get_image_normalization(struct: TaskStruct) → Tuple[RGBInfo | None, RGBInfo | None]

Returns the applied / required image normalization information.

Returns:

A tuple of per-channel means and standard deviations. If no normalization is applied, both elements are None.

Return type:

Tuple[Optional[RGBInfo], Optional[RGBInfo]]
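A consumer of `get_image_normalization()` has to handle the case where both elements are `None`. The sketch below shows this with standard per-channel normalization `(x - mean) / std`; the `RGB` named tuple is a hypothetical stand-in for `RGBInfo`, whose actual fields are not documented here, and `normalize_pixel` is likewise illustrative.

```python
from typing import NamedTuple, Optional, Tuple


class RGB(NamedTuple):
    # Hypothetical stand-in for RGBInfo: one value per colour channel.
    r: float
    g: float
    b: float


def normalize_pixel(
    pixel: Tuple[float, float, float],
    norm: Tuple[Optional[RGB], Optional[RGB]],
) -> Tuple[float, ...]:
    """Apply (x - mean) / std per channel, or pass through if no normalization."""
    means, stds = norm
    if means is None or stds is None:
        # No normalization is applied: return the pixel unchanged.
        return tuple(pixel)
    return tuple((x - m) / s for x, m, s in zip(pixel, means, stds))
```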

get_loader_kwargs_from_cfg(task_name: str, phase: LearningPhase = LearningPhase.TRAIN) → Dict[str, Any]
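`get_loader_kwargs_from_cfg()` resolves data loader keyword arguments for a task and phase. Since the config layout is not documented here, the following sketch assumes a hypothetical plain-dict config with global defaults (`loader_defaults`) and optional per-task overrides (`task_overrides`), and enables shuffling only for training. None of these keys are confirmed MML config entries.

```python
from typing import Any, Dict


def loader_kwargs(cfg: Dict[str, Any], task_name: str, phase: str = "train") -> Dict[str, Any]:
    """Merge hypothetical default loader settings with per-task overrides."""
    kwargs: Dict[str, Any] = dict(cfg.get("loader_defaults", {}))
    # Per-task values take precedence over the global defaults.
    kwargs.update(cfg.get("task_overrides", {}).get(task_name, {}))
    # Shuffle only while training so evaluation order is reproducible.
    kwargs["shuffle"] = phase == "train"
    return kwargs
```

With defaults `{"batch_size": 32, "num_workers": 4}` and an override `{"batch_size": 8}` for one task, that task trains with batch size 8 while all other settings fall back to the defaults.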
get_modality_loaders() → Dict[Modality, ModalityLoader]

Creates ModalityLoader instances from the config.
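One way to build a `Modality` to loader mapping from a config is a registry of loader factories keyed by name. The sketch below is purely illustrative: `Modality`, the two loader classes, `LOADER_REGISTRY` and the config format are all hypothetical and do not reflect MML's actual classes.

```python
from enum import Enum
from typing import Callable, Dict


class Modality(Enum):
    # Hypothetical subset of modalities.
    IMAGE = "image"
    CLASS = "class"


class ImageLoader:
    def load(self, entry: str) -> str:
        return f"image<{entry}>"


class ClassLoader:
    def load(self, entry: str) -> int:
        return int(entry)


# Registry: config string -> factory that creates the loader instance.
LOADER_REGISTRY: Dict[str, Callable[[], object]] = {
    "image_default": ImageLoader,
    "class_default": ClassLoader,
}


def build_modality_loaders(cfg: Dict[str, str]) -> Dict[Modality, object]:
    """Instantiate one loader per modality named in the config."""
    return {Modality(name): LOADER_REGISTRY[loader]() for name, loader in cfg.items()}
```

A registry keeps the config declarative: swapping the loader for a modality only requires changing a string, not code.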

on_after_batch_transfer(batch: Any, dataloader_idx: int) → Any

Enables GPU augmentations once the batch has been transferred to the device.
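`on_after_batch_transfer()` is a Lightning hook that receives the batch after it has been moved to the target device, which makes it a natural place for GPU-side augmentation. The plain-Python sketch below mimics the typical pattern of augmenting only during training. The `training` flag and the `gpu_augment` function are hypothetical; inside a real datamodule one would usually check `self.trainer.training` instead.

```python
from typing import Callable, Dict, List

Batch = Dict[str, List[float]]


def gpu_augment(batch: Batch) -> Batch:
    # Hypothetical augmentation: add a constant offset to every image value.
    return {**batch, "image": [v + 1.0 for v in batch["image"]]}


def after_batch_transfer(
    batch: Batch,
    dataloader_idx: int,
    training: bool,
    augment: Callable[[Batch], Batch] = gpu_augment,
) -> Batch:
    """Augment the already-transferred batch only while training.

    dataloader_idx mirrors the hook's signature but is unused here.
    """
    if training:
        return augment(batch)
    # Validation, test and predict batches pass through unchanged.
    return batch
```

Running augmentations after the transfer lets them execute on the accelerator, which keeps CPU data-loader workers free for decoding and lighter preprocessing.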

predict_dataloader(*args, **kwargs) → CombinedLoader
prepare_data(*args, **kwargs)
setup(stage: str) → None

Implements the Lightning interface to prepare the datamodule. In particular, it sets up the TaskDataset instances.

Parameters:

stage (str): the Lightning stage being set up, one of "fit", "validate", "test" or "predict"

Returns:

None
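For each Lightning stage, `setup()` only needs the dataset splits that stage will iterate. The sketch below encodes one plausible mapping; the split names and the `splits_for_stage` helper are assumptions, not MML's actual logic.

```python
from typing import Dict, Tuple

# Which dataset splits each Lightning stage needs (hypothetical split names).
STAGE_SPLITS: Dict[str, Tuple[str, ...]] = {
    "fit": ("train", "val"),
    "validate": ("val",),
    "test": ("test",),
    "predict": ("predict",),
}


def splits_for_stage(stage: str) -> Tuple[str, ...]:
    """Return the splits to instantiate in setup(stage)."""
    try:
        return STAGE_SPLITS[stage]
    except KeyError:
        raise ValueError(f"unknown stage: {stage!r}") from None
```

Building only the splits a stage needs avoids loading, for example, test data during training.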

teardown(stage: str | None = None) → None
test_dataloader(*args, **kwargs) → CombinedLoader
train_dataloader(*args, **kwargs) → CombinedLoader
val_dataloader(*args, **kwargs) → CombinedLoader
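The dataloader methods return Lightning's `CombinedLoader`, which iterates several loaders (here, one per task) together. Which mode MML configures is not stated here. The sketch below re-implements the idea of the `"max_size_cycle"` mode in plain Python: shorter loaders are restarted until the longest one is exhausted. The `max_size_cycle` function is hypothetical and only reproduces the batch contents, not Lightning's exact iteration protocol.

```python
from itertools import cycle
from typing import Dict, Iterator, Mapping, Sequence, TypeVar

T = TypeVar("T")


def max_size_cycle(loaders: Mapping[str, Sequence[T]]) -> Iterator[Dict[str, T]]:
    """Yield {task: batch} until the longest loader is exhausted.

    Shorter loaders are cycled as needed; every loader must be non-empty.
    """
    if not loaders:
        return
    if any(len(seq) == 0 for seq in loaders.values()):
        raise ValueError("all loaders must be non-empty")
    n_steps = max(len(seq) for seq in loaders.values())
    iterators = {name: cycle(seq) for name, seq in loaders.items()}
    for _ in range(n_steps):
        # One step draws one batch from every task.
        yield {name: next(it) for name, it in iterators.items()}
```

For `{"a": [1, 2, 3], "b": [10, 20]}` this yields three steps, with task `b` restarting at `10` on the last one, so every training step sees a batch from every task.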