mml.core.models.lightning_single_frame

class SingleFrameLightningModule

Bases: LightningModule

The default MML Lightning module, supporting frame-wise training and inference.

__init__(task_structs: List[TaskStruct], cfg: DictConfig, weights: List[float] | None = None)
compute_and_log_loss(logits: Dict[str, Tensor], targets: Dict[str, Tensor], phase: LearningPhase) → Tensor
compute_and_log_metrics(logits: Dict[str, Tensor], targets: Dict[str, Tensor], phase: LearningPhase)
configure_optimizers()
forward(x: Tensor) → Dict[str, Tensor]

Default forward method. It is not used by PyTorch Lightning itself; it is provided as an inference option for external callers.

Parameters:

x (torch.Tensor) – plain batch or single image (no modality dict!)

Returns:

dict with one entry per model head and corresponding prediction logits
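
As a rough illustration of consuming the returned dict, the sketch below turns per-head logits into predicted class indices. It is not part of MML and assumes classification heads. Plain nested lists stand in for tensors so the snippet runs without PyTorch, and the head names and the helper `predicted_classes` are hypothetical.

```python
from typing import Dict, List


def predicted_classes(logits: Dict[str, List[List[float]]]) -> Dict[str, List[int]]:
    """Map each head name to the argmax class index of every sample."""
    predictions = {}
    for head, batch_logits in logits.items():
        # one row of logits per sample; the largest logit marks the predicted class
        predictions[head] = [row.index(max(row)) for row in batch_logits]
    return predictions


# hypothetical output of forward() for two heads and a batch of two images
example_logits = {
    "task_a": [[0.1, 2.3, -1.0], [1.5, 0.2, 0.3]],
    "task_b": [[-0.4, 0.9], [0.7, 0.1]],
}
print(predicted_classes(example_logits))
```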

forward_features(x: Tensor) → Tensor

Special forward method that generates embeddings for images. It is not used by PyTorch Lightning itself; it is provided as an embedding generator for external callers.

Parameters:

x (torch.Tensor) – plain batch or single image (no modality dict!)

Returns:

tensor of shape num_samples x num_features # TODO verify

get_criteria() → ModuleDict

Generates the criteria modules. These correspond to the loss functions of each task. This is run once at the initialisation of the lightning module.

Returns:

a dict mapping each task to its loss module

get_metrics(struct: TaskStruct) → List[Metric]

Generates a collection of metrics, suited for the given task, based on the configs.

Parameters:

struct (TaskStruct) – struct of the task

Returns:

a list of torchmetrics metrics

Return type:

List[torchmetrics.Metric]

static get_monitor_metric() → Tuple[str, str]

Returns the monitoring metric. Lightning uses it to determine the best model after training.
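
The source does not say what the two strings in the returned tuple mean. Assuming they are a metric name and a mode of "min" or "max" (the same pair that Lightning's ModelCheckpoint accepts as `monitor` and `mode`), the following sketch shows how such a pair could select the best epoch. The helper `best_epoch`, the metric name `val/loss`, and the history values are all hypothetical.

```python
from typing import Dict, List, Tuple


def best_epoch(history: List[Dict[str, float]], monitor: Tuple[str, str]) -> int:
    """Return the index of the epoch with the best monitored value."""
    metric_name, mode = monitor  # assumed layout: (metric name, "min" or "max")
    if mode not in ("min", "max"):
        raise ValueError(f"mode must be 'min' or 'max', got {mode!r}")
    values = [epoch[metric_name] for epoch in history]
    pick = min if mode == "min" else max
    return values.index(pick(values))


# hypothetical per-epoch validation results
history = [{"val/loss": 0.9}, {"val/loss": 0.4}, {"val/loss": 0.6}]
print(best_epoch(history, ("val/loss", "min")))
```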

property is_tuning: bool

Checks whether the model is currently being tuned, which allows some operations to be modified.

log_confusion_matrix(phase: LearningPhase) → None

Logging utility that shows the confusion matrix of each epoch. Each call also resets the confusion matrix in preparation for the next epoch.

Parameters:

phase (LearningPhase) – currently active learning phase, used to separate train, val and test

log_images_prediction_reference(batch: Dict[str, Dict[str, Tensor]], logits: Dict[str, Tensor], targets: Dict[str, Tensor], phase: LearningPhase) → None

Logging utility for showing image examples together with reference and model predictions.

Parameters:
  • batch (Dict[str, Dict[str, torch.Tensor]]) – batch as provided by dataloader (batch[task][modality])

  • logits (Dict[str, torch.Tensor]) – logits as provided by the model's step method

  • targets (Dict[str, torch.Tensor]) – targets as provided by the step method

  • phase (LearningPhase) – either train, val or test; used to access the underlying TaskDataset (mml.core.data_loading.task_dataset.TaskDataset) and as a logging prefix

on_test_epoch_end() → None
on_train_epoch_end() → None
on_validation_epoch_end() → None
predict_step(batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) → Any
push_and_sort(batch: Dict[str, Dict[str, Tensor]], raise_on_error: bool = True, perform_tta: bool = False) → Tuple[Dict[str, Tensor], Dict[str, Tensor]]

The “forward” method used internally by Lightning for dict-based dataloaders. It handles the dict input of the combined dataloader in any mode except “sequential” and resolves both modalities and tasks.

Parameters:
  • batch (Dict[str, Dict[str, torch.Tensor]]) – a batch of format {task_name: {modality_name: tensor_values}}

  • raise_on_error (bool) – if False accepts missing targets in the batch (e.g. during test step)

  • perform_tta (bool) – if True performs multiple forward passes with augmented batch variants and merges them

Returns:

a tuple consisting of a logits dict and a targets dict, each with one key per task

Return type:

Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
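
The input and output formats described above can be sketched as follows. This is a simplified stand-in, not the MML implementation: the modality keys `image` and `target`, the `heads` mapping of plain callables, and the function name are hypothetical, lists replace tensors, and test-time augmentation (`perform_tta`) is left out.

```python
from typing import Any, Callable, Dict, Tuple

Batch = Dict[str, Dict[str, Any]]  # {task_name: {modality_name: values}}


def push_and_sort_sketch(
    batch: Batch,
    heads: Dict[str, Callable[[Any], Any]],
    raise_on_error: bool = True,
    input_key: str = "image",    # hypothetical modality name
    target_key: str = "target",  # hypothetical modality name
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Resolve a nested batch into per-task logits and per-task targets."""
    logits: Dict[str, Any] = {}
    targets: Dict[str, Any] = {}
    for task, modalities in batch.items():
        # push the input modality through the head belonging to this task
        logits[task] = heads[task](modalities[input_key])
        if target_key in modalities:
            targets[task] = modalities[target_key]
        elif raise_on_error:
            raise KeyError(f"no targets found for task {task!r}")
    return logits, targets


heads = {"task_a": lambda values: [v * 2 for v in values]}  # toy "model head"
batch = {"task_a": {"image": [1, 2], "target": [0, 1]}}
logits, targets = push_and_sort_sketch(batch, heads)
print(logits, targets)
```

With `raise_on_error=False`, a task without targets simply has no entry in the targets dict, which mirrors the behaviour described for the test step.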

reformat_batch_from_sequential(batch: Dict[str, Tensor], dataloader_idx: int) → Dict[str, Dict[str, Tensor]]

Converts a batch from the combined loader's “sequential” mode into the default format.

Parameters:
  • batch (Dict[str, torch.Tensor]) – a batch of format {modality_name: tensor_values}

  • dataloader_idx (int) – index of the dataloader

Returns:

a batch of format {task_name: {modality_name: tensor_values}}

Return type:

Dict[str, Dict[str, torch.Tensor]]
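
A minimal sketch of this conversion, assuming that `dataloader_idx` indexes an ordered list of task names (the source does not state how the task is looked up). The function name, the `task_names` argument, and the example values are hypothetical, and lists stand in for tensors.

```python
from typing import Any, Dict, List


def reformat_from_sequential_sketch(
    batch: Dict[str, Any], dataloader_idx: int, task_names: List[str]
) -> Dict[str, Dict[str, Any]]:
    """Wrap a {modality: values} batch under the task its dataloader belongs to."""
    # assumption: dataloaders are ordered in the same way as task_names
    task = task_names[dataloader_idx]
    return {task: batch}


tasks = ["task_a", "task_b"]  # hypothetical task names
flat_batch = {"image": [0.1, 0.2], "target": [1, 0]}
nested = reformat_from_sequential_sketch(flat_batch, dataloader_idx=1, task_names=tasks)
print(nested)
```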

setup_redirection(head: str, task: str) → None

Sets up a redirection so that data from a new task is passed through an existing (old) model head. This also prepares metrics and confusion matrix logging under the new task name.

Parameters:
  • head (str) – the existing model head name (likely learned before)

  • task (str) – the new task that shall be passed through the old head


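The idea of redirecting a new task to an existing head can be sketched with a small lookup table. This is an illustrative stand-in only: the class `HeadRouterSketch`, its attributes, and the task names are hypothetical, and the metric and confusion matrix preparation mentioned in the description is not modelled.

```python
from typing import Callable, Dict


class HeadRouterSketch:
    """Toy router mapping task names to model heads, with optional redirections."""

    def __init__(self, heads: Dict[str, Callable]) -> None:
        self.heads = dict(heads)
        self.redirections: Dict[str, str] = {}  # new task name -> existing head name

    def setup_redirection(self, head: str, task: str) -> None:
        if head not in self.heads:
            raise KeyError(f"unknown head {head!r}")
        self.redirections[task] = head

    def head_for(self, task: str) -> Callable:
        # fall back to the head named like the task if no redirection exists
        return self.heads[self.redirections.get(task, task)]


router = HeadRouterSketch({"old_task": lambda x: x + 1})
router.setup_redirection(head="old_task", task="new_task")
print(router.head_for("new_task")(1))
```
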
test_step(batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) → Tensor
training_step(batch: Dict[str, Dict[str, Tensor]], batch_idx: int) → Tensor
validation_step(batch: Dict[str, Tensor], batch_idx: int, dataloader_idx: int = 0) → Tensor