mml.core.models.lightning_single_frame
- class SingleFrameLightningModule
Bases: LightningModule
- __init__(task_structs: List[TaskStruct], cfg: DictConfig, task_weights: List[float] | None = None, load_parameters: Path | None = None)
The default MML Lightning module, supporting frame-wise training and inference.
- Parameters:
task_structs (List[TaskStruct]) – TaskStruct for all tasks that the model shall interact with
cfg (DictConfig) – the main config file; multiple config groups are used (e.g., arch, loss, sampling, logging, tta, metrics, ...)
task_weights (Optional[List[float]]) – if provided, determines the weighting of each task in the loss; if None, all tasks contribute equally
load_parameters (Optional[Path]) – if given, loads the model and its weights, ignoring the current cfg.arch; if a task in task_structs already has a model head in the loaded model, that head is reused (otherwise a new head is created)
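A minimal sketch of how task_weights could enter the combined loss, assuming each weight pairs with the task at the same position in task_structs and that "contribute equally" means a weight of 1.0 per task. The helper combine_task_losses is hypothetical and works on plain floats rather than tensors; it is not the module's actual implementation.

```python
from typing import Dict, List, Optional


def combine_task_losses(
    task_losses: Dict[str, float],
    task_names: List[str],
    task_weights: Optional[List[float]] = None,
) -> float:
    """Hypothetical helper: weighted sum of per-task losses.

    Assumes task_weights[i] belongs to task_names[i] (same order as task_structs).
    Without weights, every task gets weight 1.0, i.e. equal contribution.
    """
    if task_weights is None:
        task_weights = [1.0] * len(task_names)
    if len(task_weights) != len(task_names):
        raise ValueError("need exactly one weight per task")
    weights = dict(zip(task_names, task_weights))
    # only tasks that produced a loss in this batch contribute
    return sum(weights[name] * loss for name, loss in task_losses.items())
```

For example, combine_task_losses({"a": 0.5, "b": 1.0}, ["a", "b"], [2.0, 0.0]) lets task "a" dominate and silences task "b".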
- compute_and_log_loss(logits: Dict[str, Tensor], targets: Dict[str, Tensor], phase: LearningPhase) → Tensor
- compute_and_log_metrics(logits: Dict[str, Tensor], targets: Dict[str, Tensor], phase: LearningPhase)
- forward(x: Tensor) → Dict[str, Tensor]
Default forward method. It is not used by PyTorch Lightning itself; it is exposed to external callers as an inference entry point.
- Parameters:
x (torch.Tensor) – plain batch or single image (no modality dict!)
- Returns:
a dict with one entry per model head, holding that head's prediction logits
- forward_features(x: Tensor) → Tensor
Special forward method that generates image embeddings. It is not used by PyTorch Lightning itself; it is exposed to external callers as an embedding generator.
- Parameters:
x (torch.Tensor) – plain batch or single image (no modality dict!)
- Returns:
tensor of shape num_samples x num_features # TODO verify
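The relationship between forward and forward_features can be sketched with a toy model, assuming a shared backbone that produces the embeddings and one head per task that turns them into logits. ToySingleFrameModel, its backbone and its heads are hypothetical stand-ins that use Python lists instead of torch.Tensor.

```python
from typing import Callable, Dict, List

Vector = List[float]


class ToySingleFrameModel:
    """Hypothetical stand-in: a shared backbone plus one head per task."""

    def __init__(
        self,
        backbone: Callable[[Vector], Vector],
        heads: Dict[str, Callable[[Vector], Vector]],
    ):
        self.backbone = backbone
        self.heads = heads

    def forward_features(self, batch: List[Vector]) -> List[Vector]:
        # one embedding per sample: num_samples x num_features
        return [self.backbone(x) for x in batch]

    def forward(self, batch: List[Vector]) -> Dict[str, List[Vector]]:
        features = self.forward_features(batch)
        # one entry per head, holding that head's logits for every sample
        return {task: [head(f) for f in features] for task, head in self.heads.items()}
```

Under this assumption, forward reuses forward_features, so embeddings and logits come from the same backbone pass.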
- get_criteria() → ModuleDict
Generates the criterion modules, i.e. the loss function of each task. This runs once, at the initialisation of the Lightning module.
- Returns:
a dict mapping each task to its loss module
- get_metrics(struct: TaskStruct) → List[Metric]
Generates a collection of metrics suited to the given task, based on the configs.
- Parameters:
struct (TaskStruct) – struct of the task
- Returns:
a list of torchmetrics metrics
- Return type:
List[torchmetrics.Metric]
- static get_monitor_metric() → Tuple[str, str]
Returns the monitoring metric, which Lightning uses to determine the best model after training.
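A sketch of how the returned tuple might be consumed, assuming it is (metric_name, mode) with mode being "min" or "max", the same pair that Lightning's ModelCheckpoint accepts through its monitor and mode arguments. The metric name "val/loss", the stub get_monitor_metric and the helper best_epoch are hypothetical.

```python
from typing import Dict, List, Tuple


def get_monitor_metric() -> Tuple[str, str]:
    # hypothetical values; the real name and mode come from the module
    return "val/loss", "min"


def best_epoch(history: List[Dict[str, float]], monitor: Tuple[str, str]) -> int:
    """Index of the best epoch, mimicking what a checkpoint callback selects."""
    name, mode = monitor
    if mode not in ("min", "max"):
        raise ValueError(f"unknown mode {mode!r}")
    scores = [epoch[name] for epoch in history]
    pick = min if mode == "min" else max
    return scores.index(pick(scores))
```

With this reading, the same tuple could be unpacked into ModelCheckpoint(monitor=name, mode=mode).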
- property is_tuning: bool
Checks whether the model is currently being tuned, which allows some operations to be modified.
- log_confusion_matrix(phase: LearningPhase) → None
Logging utility that shows the confusion matrix of each epoch. Each call also resets the confusion matrix in preparation for the next epoch.
- Parameters:
phase (LearningPhase) – currently active learning phase, used to separate train, val and test
- Returns:
None
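The per-epoch accumulate, log and reset cycle of log_confusion_matrix can be sketched as follows, assuming integer class predictions and a plain print in place of the real logger. EpochConfusionMatrix is hypothetical and keys phases by string rather than by LearningPhase.

```python
from typing import Dict, List


class EpochConfusionMatrix:
    """Hypothetical accumulator: one confusion matrix per phase, reset after logging."""

    def __init__(self, num_classes: int):
        self.num_classes = num_classes
        self._cms: Dict[str, List[List[int]]] = {}

    def _empty(self) -> List[List[int]]:
        return [[0] * self.num_classes for _ in range(self.num_classes)]

    def update(self, phase: str, preds: List[int], targets: List[int]) -> None:
        cm = self._cms.setdefault(phase, self._empty())
        for p, t in zip(preds, targets):
            cm[t][p] += 1  # rows: reference class, columns: predicted class

    def log_and_reset(self, phase: str) -> List[List[int]]:
        cm = self._cms.get(phase, self._empty())
        print(f"{phase} confusion matrix: {cm}")  # stand-in for a real logger
        self._cms[phase] = self._empty()  # start fresh for the next epoch
        return cm
```

Keeping one matrix per phase is what keeps train, val and test counts from mixing.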
- log_images_prediction_reference(batch: Dict[str, Dict[str, Tensor]], logits: Dict[str, Tensor], targets: Dict[str, Tensor], phase: LearningPhase) → None
Logging utility that shows example images together with their references and the model predictions.
- Parameters:
batch (Dict[str, Dict[str, torch.Tensor]]) – batch as provided by dataloader (batch[task][modality])
logits (Dict[str, torch.Tensor]) – logits as provided by the model's step()
targets (Dict[str, torch.Tensor]) – targets as provided by step()
phase (LearningPhase) – one of train, val or test; used to access the underlying data_loading:task_dataset:TaskDataset and as a logging prefix
- Returns:
None
- push_and_sort(batch: Dict[str, Dict[str, Tensor]], raise_on_error: bool = True, perform_tta: bool = False) → Tuple[Dict[str, Tensor], Dict[str, Tensor]]
The “forward” method used internally by Lightning for dict-based dataloaders. It handles the dict input of the combined dataloader in every mode except “sequential” and resolves both modalities and tasks.
- Parameters:
batch (Dict[str, Dict[str, torch.Tensor]]) – a batch of format {task_name: {modality_name: tensor_values}}
raise_on_error (bool) – if False, missing targets in the batch are accepted (e.g., during the test step)
perform_tta (bool) – if True, performs multiple forward passes on augmented variants of the batch and merges the results
- Returns:
a tuple of a logits dict and a targets dict, each keyed by task
- Return type:
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]
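A minimal sketch of the documented push_and_sort behaviour on plain lists. It assumes the input modality is named "image" and the reference "target", that the model is reachable through a forward(task, images) callable, and that TTA merges the variants by averaging logits; passing tta_transforms here stands in for perform_tta=True. All of these names are hypothetical.

```python
from typing import Callable, Dict, List, Optional, Tuple

Vector = List[float]
Batch = Dict[str, Dict[str, list]]


def push_and_sort_sketch(
    batch: Batch,
    forward: Callable[[str, List[Vector]], List[Vector]],
    raise_on_error: bool = True,
    tta_transforms: Optional[List[Callable[[Vector], Vector]]] = None,
) -> Tuple[Dict[str, List[Vector]], Dict[str, list]]:
    """Hypothetical re-creation: {task: {modality: values}} -> (logits, targets)."""
    logits, targets = {}, {}
    for task, modalities in batch.items():
        images = modalities["image"]  # assumed modality name
        if tta_transforms:
            # one forward pass per augmented variant of the batch
            variants = [forward(task, [t(x) for x in images]) for t in tta_transforms]
            # element-wise mean over variants, per sample and per logit
            logits[task] = [
                [sum(vals) / len(vals) for vals in zip(*per_sample)]
                for per_sample in zip(*variants)
            ]
        else:
            logits[task] = forward(task, images)
        if "target" in modalities:  # assumed modality name
            targets[task] = modalities["target"]
        elif raise_on_error:
            raise KeyError(f"batch for task {task!r} has no target")
    return logits, targets
```

Averaging is only one possible merge strategy; the module's TTA settings may combine variants differently.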
- reformat_batch_from_sequential(batch: Dict[str, Tensor], dataloader_idx: int) → Dict[str, Dict[str, Tensor]]
Converts a batch from the “sequential” mode combined loader into the default format.
- Parameters:
batch (Dict[str, torch.Tensor]) – a batch of format {modality_name: tensor_values}
dataloader_idx (int) – index of the dataloader
- Returns:
a batch of format {task_name: {modality_name: tensor_values}}
- Return type:
Dict[str, Dict[str, torch.Tensor]]
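Assuming the “sequential” combined loader visits the task dataloaders in a fixed task order, so that dataloader_idx identifies the task, the conversion reduces to wrapping the batch under the matching task name. The function and its task_names argument are hypothetical.

```python
from typing import Dict, List, TypeVar

T = TypeVar("T")


def reformat_batch_from_sequential_sketch(
    batch: Dict[str, T], dataloader_idx: int, task_names: List[str]
) -> Dict[str, Dict[str, T]]:
    """Wrap a {modality: values} batch as {task: {modality: values}}.

    Assumes task_names lists the tasks in the order of their dataloaders.
    """
    return {task_names[dataloader_idx]: batch}
```

The result has the same shape that push_and_sort expects, so both loader modes can share one downstream path.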