mml.core.data_loading.lightning_datamodule
- class MultiTaskDataModule
Bases: LightningDataModule

This class wraps one or multiple TaskDatasets for Lightning. Given the respective TaskStructs, it takes care of setting up the correct data splits. In particular, it interprets the following elements of the config:
  * loaders: the ModalityLoaders
  * preprocessing: the AugmentationModule
  * augmentations: also an AugmentationModule
  * all aspects with respect to sampling and the dataloader
Importantly, it provides the necessary Lightning interface (e.g., setup(), train_dataloader(), etc.).
- __init__(task_structs: List[TaskStruct], cfg: DictConfig, fold: int = 0)
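The `fold` argument of `__init__`, together with the statement that the class sets up the data splits, suggests splits are chosen per cross-validation fold. The following is a minimal sketch of one common way a fold index can select train/validation subsets. The round-robin assignment, the `n_folds` default and the helper name `kfold_split` are assumptions for illustration, not MML's actual split logic.

```python
from typing import List, Sequence, Tuple


def kfold_split(
    indices: Sequence[int], fold: int, n_folds: int = 5
) -> Tuple[List[int], List[int]]:
    """Split indices into (train, val) for one cross-validation fold.

    Hypothetical helper: samples are assigned to folds round-robin by
    position, and the fold numbered ``fold`` becomes the validation split.
    """
    if not 0 <= fold < n_folds:
        raise ValueError(f"fold must be in [0, {n_folds}), got {fold}")
    train: List[int] = []
    val: List[int] = []
    for position, idx in enumerate(indices):
        # position % n_folds decides which fold a sample belongs to
        if position % n_folds == fold:
            val.append(idx)
        else:
            train.append(idx)
    return train, val


train_idx, val_idx = kfold_split(range(10), fold=0)
print(val_idx)  # [0, 5]
```

Across fold = 0 to n_folds - 1, every index lands in exactly one validation split, which is the property a per-fold data module relies on.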
- get_cpu_transforms(struct: TaskStruct, phase: LearningPhase = LearningPhase.TRAIN) → AugmentationModule | AugmentationModuleContainer
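Returns the necessary CPU transforms for the given task and learning phase.
- Parameters:
  struct – the TaskStruct to build the transforms for
  phase – the learning phase (defaults to LearningPhase.TRAIN)
- Returns:
  an AugmentationModule or an AugmentationModuleContainer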
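get_cpu_transforms returns either an AugmentationModule or an AugmentationModuleContainer, and the config distinguishes preprocessing from augmentations. The following is a minimal sketch of one plausible composition rule, assuming preprocessing always runs and augmentations are only added in the training phase. `Phase`, `TransformContainer` and `build_cpu_transforms` are hypothetical stand-ins, not MML classes.

```python
from enum import Enum
from typing import Any, Callable, List

Transform = Callable[[Any], Any]


class Phase(Enum):
    """Hypothetical stand-in for MML's LearningPhase."""
    TRAIN = "train"
    VAL = "val"
    TEST = "test"


class TransformContainer:
    """Hypothetical stand-in for a container of augmentation modules:
    applies its transforms one after another."""

    def __init__(self, transforms: List[Transform]) -> None:
        self.transforms = transforms

    def __call__(self, sample: Any) -> Any:
        for transform in self.transforms:
            sample = transform(sample)
        return sample


def build_cpu_transforms(
    preprocessing: List[Transform],
    augmentations: List[Transform],
    phase: Phase = Phase.TRAIN,
) -> TransformContainer:
    # Preprocessing runs in every phase; augmentations only while training.
    steps = list(preprocessing)
    if phase is Phase.TRAIN:
        steps.extend(augmentations)
    return TransformContainer(steps)


def double(x: int) -> int:
    return x * 2


def add_one(x: int) -> int:
    return x + 1


print(build_cpu_transforms([double], [add_one], Phase.TRAIN)(3))  # 7
print(build_cpu_transforms([double], [add_one], Phase.TEST)(3))   # 6
```

Keeping preprocessing first means evaluation data passes through the same deterministic steps as training data, while random augmentations stay confined to training.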
- static get_dataset_balancing_weights(ds: TaskDataset) → Tensor
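The docstring does not state how get_dataset_balancing_weights computes its weights. Inverse class frequency is a common choice for per-sample balancing weights, so the following sketch assumes it. It works on a plain label list instead of a TaskDataset and returns Python floats rather than a Tensor. The helper name `inverse_frequency_weights` is hypothetical.

```python
from collections import Counter
from typing import Hashable, List, Sequence


def inverse_frequency_weights(labels: Sequence[Hashable]) -> List[float]:
    """Per-sample weight = 1 / (number of samples sharing that label).

    Every class then carries the same total weight, so sampling
    proportionally to these weights draws each class equally often
    in expectation.
    """
    counts = Counter(labels)
    return [1.0 / counts[label] for label in labels]


weights = inverse_frequency_weights(["cat", "cat", "cat", "dog"])
print(weights[-1])  # 1.0
```

For sampling in PyTorch, such weights could be converted with torch.tensor and passed to torch.utils.data.WeightedRandomSampler together with a num_samples value.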
- get_image_normalization(struct: TaskStruct) → Tuple[RGBInfo | None, RGBInfo | None]
Returns the applied / required image normalization information.
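The following is a minimal sketch of how an (applied, required) pair of normalization descriptions could be used to convert pixel values. It assumes that the first element describes the normalization already applied, that the second describes the one a consumer expects, and that both hold a per-channel mean and standard deviation. `ChannelStats` and `renormalize` are hypothetical; the actual contents of RGBInfo are not documented here.

```python
from typing import NamedTuple, Optional, Tuple

Pixel = Tuple[float, float, float]


class ChannelStats(NamedTuple):
    """Hypothetical stand-in for RGBInfo: per-channel mean and std."""
    mean: Pixel
    std: Pixel


def renormalize(
    pixel: Pixel,
    applied: Optional[ChannelStats],
    required: Optional[ChannelStats],
) -> Pixel:
    """Map a pixel from the applied normalization to the required one."""
    out = []
    for c, value in enumerate(pixel):
        if applied is not None:
            # undo the normalization that has already been applied
            value = value * applied.std[c] + applied.mean[c]
        if required is not None:
            # apply the normalization the consumer expects
            value = (value - required.mean[c]) / required.std[c]
        out.append(value)
    return (out[0], out[1], out[2])


applied = ChannelStats(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
required = ChannelStats(mean=(0.5, 0.5, 0.5), std=(0.25, 0.25, 0.25))
print(renormalize((1.0, -1.0, 0.0), applied, None))      # (1.0, 0.0, 0.5)
print(renormalize((1.0, -1.0, 0.0), applied, required))  # (2.0, -2.0, 0.0)
```

A None entry on either side simply skips that step, mirroring the Optional elements in the return type.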
- get_loader_kwargs_from_cfg(task_name: str, phase: LearningPhase = LearningPhase.TRAIN) → Dict[str, Any]
- get_modality_loaders() → Dict[Modality, ModalityLoader]
Creates ModalityLoader instances from the config.
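The class description states that the `loaders` config entry specifies the ModalityLoaders. The following is a minimal registry-based sketch of building one loader per modality from such a section. It assumes the section maps modality names to registered loader names. `Modality`, `Loader`, `LOADER_REGISTRY` and `build_modality_loaders` are hypothetical stand-ins, not MML's implementation.

```python
from enum import Enum
from typing import Dict, Type


class Modality(Enum):
    """Hypothetical stand-in for MML's Modality enum."""
    IMAGE = "image"
    MASK = "mask"


class Loader:
    """Hypothetical stand-in for the ModalityLoader base class."""

    def load(self, path: str) -> str:
        raise NotImplementedError


class ImageLoader(Loader):
    def load(self, path: str) -> str:
        return f"image from {path}"


class MaskLoader(Loader):
    def load(self, path: str) -> str:
        return f"mask from {path}"


# Hypothetical registry resolving loader names used in the config.
LOADER_REGISTRY: Dict[str, Type[Loader]] = {
    "image_loader": ImageLoader,
    "mask_loader": MaskLoader,
}


def build_modality_loaders(loaders_cfg: Dict[str, str]) -> Dict[Modality, Loader]:
    """Instantiate one loader per modality listed in the ``loaders`` section."""
    return {
        Modality(modality): LOADER_REGISTRY[loader_name]()
        for modality, loader_name in loaders_cfg.items()
    }


loaders = build_modality_loaders({"image": "image_loader", "mask": "mask_loader"})
print(loaders[Modality.IMAGE].load("a.png"))  # image from a.png
```

Keying the result by the Modality enum rather than by raw strings lets an unknown modality name fail early with a ValueError when the config is read.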
- on_after_batch_transfer(batch: Any, dataloader_idx: int) → Any
Enables GPU augmentations after the batch has been transferred to the device.
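Lightning calls on_after_batch_transfer once a batch is on the target device, which makes it a natural place for batched GPU augmentations. The following is a framework-free sketch of that pattern. It assumes dict batches with an "image" key and applies augmentation only during training. `GpuAugmentingDataModule`, `gpu_transform` and the `training` flag are hypothetical names.

```python
from typing import Any, Callable, Dict

Batch = Dict[str, Any]


class GpuAugmentingDataModule:
    """Hypothetical sketch mirroring the signature of Lightning's
    on_after_batch_transfer hook, without depending on Lightning."""

    def __init__(self, gpu_transform: Callable[[Any], Any], training: bool = True) -> None:
        self.gpu_transform = gpu_transform
        # stands in for a trainer-state check such as ``self.trainer.training``
        self.training = training

    def on_after_batch_transfer(self, batch: Batch, dataloader_idx: int) -> Batch:
        # At this point the batch already lives on the target device, so the
        # transform operates on the whole batch there.
        if not self.training:
            return batch
        return {**batch, "image": self.gpu_transform(batch["image"])}


dm = GpuAugmentingDataModule(lambda images: [x * 2 for x in images])
print(dm.on_after_batch_transfer({"image": [1, 2], "label": [0, 1]}, dataloader_idx=0))
# {'image': [2, 4], 'label': [0, 1]}
```

Deferring augmentation to this hook lets it run on whole batches on the accelerator instead of per sample in CPU dataloader workers.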
- setup(stage: str) → None
Implements the Lightning interface to prepare the datamodule. In particular, it sets up the TaskDatasets.
- Parameters:
  stage – the Lightning stage (e.g., "fit" or "test")
- Returns: