# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List
import lightning
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar, TQDMProgressBar
from mml.core.scripts.utils import LearningPhase
# module-level logger, named after this module
logger = logging.getLogger(__name__)
class StopAfterKeyboardInterrupt(lightning.Callback):
    """
    Makes sure pytorch lightning really shuts down after a keyboard interrupt. This is the new variant of this Callback,
    aimed to be used with the most recent pytorch lightning versions.

    When lightning reports a ``KeyboardInterrupt`` through ``on_exception`` and the trainer is flagged as interrupted,
    this callback escalates the situation into an ``InterruptedError`` so MML stops as well.
    """

    def on_exception(
        self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule", exception: BaseException
    ) -> None:
        """
        Escalate a keyboard interrupt of the trainer into an ``InterruptedError``.

        :param trainer: the lightning trainer on which the exception occurred
        :param pl_module: the lightning module (not used)
        :param exception: the exception passed on by lightning
        :raises InterruptedError: if the trainer is interrupted and ``exception`` is a ``KeyboardInterrupt``
        """
        # nothing to do unless this is an interrupted trainer AND the cause is a keyboard interrupt
        if not trainer.interrupted or not isinstance(exception, KeyboardInterrupt):
            return
        message = (
            "Trainer has been interrupted by keyboard! "
            "Will stop running MML - no graceful shutdown, ongoing epoch results are lost! "
            "Run MML in continue mode to start from the checkpoint of last epochs end."
        )
        raise InterruptedError(message)
class MetricsTrackerCallback(lightning.Callback):
    """
    Keeps track of all metrics, at the end of each epoch.

    ``self.metrics`` holds one dict per epoch, mapping metric names to python scalars. Train and validation metrics of
    the same epoch are merged into the same dict; test metrics are gathered once testing has ended. The list is exposed
    via ``state_dict`` / ``load_state_dict`` so lightning can persist it as callback state.
    """

    def __init__(self):
        super().__init__()
        # one entry per epoch: metric name -> scalar value
        self.metrics: List[Dict[str, float]] = []

    def state_dict(self) -> Dict[str, Any]:
        """Return the callback state, i.e. the recorded per-epoch metrics."""
        return {"metrics": self.metrics}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Restore the recorded per-epoch metrics from a previously returned ``state_dict``."""
        self.metrics = state_dict["metrics"]

    def on_train_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        self.copy_metrics(trainer=trainer, phase=LearningPhase.TRAIN)

    def on_validation_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        self.copy_metrics(trainer=trainer, phase=LearningPhase.VAL)

    # we need to gather test metrics only at the end (after the LightningModule's test epoch end hook) to wait for
    # bootstrapped computation
    def on_test_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        self.copy_metrics(trainer=trainer, phase=LearningPhase.TEST)

    def copy_metrics(self, trainer: lightning.Trainer, phase: LearningPhase) -> None:
        """
        Copy the metrics of ``phase`` from ``trainer.callback_metrics`` into the entry of the current epoch.

        Only metrics whose name starts with ``str(phase.value)`` are taken. Each value is converted with ``.item()``.

        :param trainer: trainer whose ``callback_metrics`` and ``current_epoch`` are read
        :param phase: learning phase whose value serves as metric name prefix filter
        """
        prefix = str(phase.value)
        # ``.item()`` already returns an independent python scalar, so there is no need to deep-copy the whole
        # metrics dict (including tensors of other phases, possibly living on GPU) beforehand
        phase_metrics = {
            name: metric_tensor.item()
            for name, metric_tensor in trainer.callback_metrics.items()
            if name.startswith(prefix)
        }
        logger.debug(f"{phase=}, logged metrics {phase_metrics.keys()}")
        epoch = trainer.current_epoch
        if len(self.metrics) == epoch:
            # first time copying for this epoch
            self.metrics.append(phase_metrics)
        elif len(self.metrics) == epoch + 1:
            # updating this epoch, for example adding train to val metrics
            self.metrics[epoch].update(phase_metrics)
        else:
            # either some epochs were missed (positive diff) or more entries exist than epochs (negative diff)
            diff = epoch - len(self.metrics)
            logger.error(f"There is a discrepancy of {diff} between metrics recorded and the current epoch!")
            # pad missed epochs with empty dicts so the new entry lands at index ``epoch`` (no-op for negative diff)
            self.metrics.extend({} for _ in range(diff))
            self.metrics.append(phase_metrics)
class MMLRichProgressBar(RichProgressBar):
    """
    Lightning's rich progress bar, tweaked to display the MML experiment name instead of the version number.

    The experiment name consists of the last two components of the current working directory and is re-evaluated
    every time the metrics are fetched.
    """

    def get_metrics(self, trainer, model):
        """
        Return the progress bar items without ``v_num`` but with an ``exp`` entry.

        :param trainer: the lightning trainer
        :param model: the lightning module
        :return: the items to display in the progress bar
        """
        shown = super().get_metrics(trainer, model)
        # the lightning version number carries no information within MML, so hide it
        if "v_num" in shown:
            del shown["v_num"]
        last_two_dirs = Path(os.getcwd()).parts[-2:]
        shown["exp"] = "/".join(last_two_dirs)
        return shown
class MMLTQDMProgressBar(TQDMProgressBar):
    """
    Slight modification of the Lightning tqdm progress bar, showing the correct experiment name.

    The experiment name (last two components of the current working directory) is determined once, at construction.
    """

    def __init__(self, refresh_rate=1, **kwargs):
        """
        :param refresh_rate: forwarded to :class:`TQDMProgressBar`
        :param kwargs: further keyword arguments of :class:`TQDMProgressBar` (e.g. ``process_position``), forwarded
            unchanged so the base progress bar stays fully configurable
        """
        super().__init__(refresh_rate=refresh_rate, **kwargs)
        self.experiment_name = "/".join(Path(os.getcwd()).parts[-2:])

    def get_metrics(self, trainer, model):
        """
        Return the progress bar items without ``v_num`` but with the cached experiment name under ``exp``.

        :param trainer: the lightning trainer
        :param model: the lightning module
        :return: the items to display in the progress bar
        """
        # don't show the version number
        items = super().get_metrics(trainer, model)
        items.pop("v_num", None)
        items["exp"] = self.experiment_name
        return items
class MMLModelCheckpoint(ModelCheckpoint):
    """
    Slight modification of the lightning ModelCheckpoint (see
    https://github.com/Lightning-AI/pytorch-lightning/issues/20245).

    The "last" checkpoint is only refreshed when the most recently saved checkpoint belongs to the current global
    step, and the epoch-end hooks only save on epochs matching ``every_n_epochs``.
    """

    def _save_last_checkpoint(self, trainer: "lightning.Trainer", monitor_candidates: Dict[str, torch.Tensor]) -> None:
        """Only update last checkpoint in case there has just been a new checkpoint."""
        just_saved = self._last_global_step_saved == trainer.global_step
        if not just_saved:
            return
        super()._save_last_checkpoint(trainer=trainer, monitor_candidates=monitor_candidates)

    def _mml_is_checkpoint_epoch(self, trainer: "lightning.Trainer") -> bool:
        """Whether ``every_n_epochs`` requests a checkpoint at the end of the trainer's current epoch."""
        every_n = self._every_n_epochs
        return every_n >= 1 and (trainer.current_epoch + 1) % every_n == 0

    def on_train_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        """Save a checkpoint at the end of the training epoch."""
        if self._should_skip_saving_checkpoint(trainer) or not self._should_save_on_train_epoch_end(trainer):
            return
        candidates = self._monitor_candidates(trainer)
        if not self._mml_is_checkpoint_epoch(trainer):
            return
        self._save_topk_checkpoint(trainer, candidates)
        # gated by the override above: only writes if a checkpoint was just saved at this global step
        self._save_last_checkpoint(trainer, candidates)

    def on_validation_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        """Save a checkpoint at the end of the validation stage."""
        if self._should_skip_saving_checkpoint(trainer) or self._should_save_on_train_epoch_end(trainer):
            return
        if not self._mml_is_checkpoint_epoch(trainer):
            return
        candidates = self._monitor_candidates(trainer)
        self._save_topk_checkpoint(trainer, candidates)
        # gated by the override above: only writes if a checkpoint was just saved at this global step
        self._save_last_checkpoint(trainer, candidates)