Source code for mml.core.scripts.callbacks

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List

import lightning
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar, TQDMProgressBar

from mml.core.scripts.utils import LearningPhase

logger = logging.getLogger(__name__)


[docs] class StopAfterKeyboardInterrupt(lightning.Callback): """ Ensures pytorch lightning to really shut down after keyboard interrupt. This is the new variant of this Callback aimed to be used by most recent pytorch lightning versions. """
[docs] def on_exception( self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule", exception: BaseException ) -> None: if trainer.interrupted and isinstance(exception, KeyboardInterrupt): raise InterruptedError( "Trainer has been interrupted by keyboard! " "Will stop running MML - no graceful shutdown, ongoing epoch results are lost! " "Run MML in continue mode to start from the checkpoint of last epochs end." )
[docs] class MetricsTrackerCallback(lightning.Callback): """ Keeps track of all metrics, at the end of each epoch. """
[docs] def __init__(self): self.metrics: List[Dict[str, float]] = []
[docs] def state_dict(self) -> Dict[str, Any]: return {"metrics": self.metrics}
[docs] def load_state_dict(self, state_dict: Dict[str, Any]) -> None: self.metrics = state_dict["metrics"]
[docs] def on_train_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None: self.copy_metrics(trainer=trainer, phase=LearningPhase.TRAIN)
[docs] def on_validation_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None: self.copy_metrics(trainer=trainer, phase=LearningPhase.VAL)
# we need to gather test metrics only at the end (after module.test_epoch_end) to wait for bootstrapped computation
[docs] def on_test_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None: self.copy_metrics(trainer=trainer, phase=LearningPhase.TEST)
[docs] def copy_metrics(self, trainer: lightning.Trainer, phase: LearningPhase) -> None: phase_metrics = copy.deepcopy(trainer.callback_metrics) # Dict[str, torch.Tensor] phase_metrics = { name: metric_tensor.item() for name, metric_tensor in phase_metrics.items() if name.startswith(str(phase.value)) } logger.debug(f"{phase=}, logged metrics {phase_metrics.keys()}") if len(self.metrics) == trainer.current_epoch: # first time copying for this epoch self.metrics.append(phase_metrics) elif len(self.metrics) == trainer.current_epoch + 1: # updating this epoch, for example adding train to val metrics self.metrics[trainer.current_epoch].update(phase_metrics) else: # we might have missed some epochs? diff = trainer.current_epoch - len(self.metrics) logger.error(f"There is a discrepancy of {diff} between metrics recorded and the current epoch!") for _ in range(diff): self.metrics.append({}) self.metrics.append(phase_metrics)
[docs] class MMLRichProgressBar(RichProgressBar): """ Slight modification of the Lightning rich progress bar, showing the correct experiment name. """
[docs] def get_metrics(self, trainer, model): # don't show the version number items = super().get_metrics(trainer, model) items.pop("v_num", None) items["exp"] = "/".join(Path(os.getcwd()).parts[-2:]) return items
[docs] class MMLTQDMProgressBar(TQDMProgressBar): """ Slight modification of the Lightning tqdm progress bar, showing the correct experiment name. """
[docs] def __init__(self, refresh_rate=1): super().__init__(refresh_rate=refresh_rate) self.experiment_name = "/".join(Path(os.getcwd()).parts[-2:])
[docs] def get_metrics(self, trainer, model): # don't show the version number items = super().get_metrics(trainer, model) items.pop("v_num", None) items["exp"] = self.experiment_name return items
[docs] class MMLModelCheckpoint(ModelCheckpoint): """ Slight modification of the lightning ModelCheckpoint (see https://github.com/Lightning-AI/pytorch-lightning/issues/20245). """ def _save_last_checkpoint(self, trainer: "lightning.Trainer", monitor_candidates: Dict[str, torch.Tensor]) -> None: """Only update last checkpoint in case there has just been a new checkpoint.""" if self._last_global_step_saved == trainer.global_step: super()._save_last_checkpoint(trainer=trainer, monitor_candidates=monitor_candidates)
[docs] def on_train_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None: """Save a checkpoint at the end of the training epoch.""" if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer): monitor_candidates = self._monitor_candidates(trainer) if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates)
[docs] def on_validation_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None: """Save a checkpoint at the end of the validation stage.""" if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer): if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0: monitor_candidates = self._monitor_candidates(trainer) self._save_topk_checkpoint(trainer, monitor_candidates) self._save_last_checkpoint(trainer, monitor_candidates)