# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
import copy
import logging
import os
from pathlib import Path
from typing import Any, Dict, List
import lightning
import lightning.pytorch.tuner.lr_finder
import torch
from lightning.pytorch.callbacks import ModelCheckpoint, RichProgressBar, TQDMProgressBar
from mml.core.scripts.utils import LearningPhase
logger = logging.getLogger(__name__)
[docs]
class StopAfterKeyboardInterrupt(lightning.Callback):
    """
    Makes sure a keyboard interrupt really terminates the run instead of letting pytorch lightning continue.

    This is the variant of the callback intended for recent pytorch lightning versions.
    """

    def on_exception(
        self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule", exception: BaseException
    ) -> None:
        """
        Escalate a keyboard interrupt of the trainer to an ``InterruptedError``.

        :param trainer: the trainer, its ``interrupted`` flag is inspected
        :param pl_module: the lightning module (unused)
        :param exception: the exception that triggered this hook
        :raises InterruptedError: if the trainer is flagged as interrupted and ``exception`` is a
            ``KeyboardInterrupt``
        :return: None
        """
        # anything that is not a keyboard interrupt of the trainer is none of our business
        if not trainer.interrupted or not isinstance(exception, KeyboardInterrupt):
            return
        message = (
            "Trainer has been interrupted by keyboard! "
            "Will stop running MML - no graceful shutdown, ongoing epoch results are lost! "
            "Run MML in continue mode to start from the checkpoint of last epochs end."
        )
        raise InterruptedError(message)
[docs]
class MetricsTrackerCallback(lightning.Callback):
    """
    Keeps track of all metrics, at the end of each epoch.

    The metrics are stored in :attr:`metrics`, a list holding one dictionary per epoch (the list index corresponds
    to the epoch) that maps metric names to scalar values. The list is exposed through :meth:`state_dict` and
    restored through :meth:`load_state_dict`.
    """

    def __init__(self):
        super().__init__()
        # one dictionary per epoch, mapping metric name -> scalar value
        self.metrics: List[Dict[str, float]] = []

    def state_dict(self) -> Dict[str, Any]:
        """
        Provide the recorded metrics as callback state.

        :return: dictionary with the single key ``"metrics"`` holding the list of per-epoch metric dictionaries
        """
        return {"metrics": self.metrics}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore previously recorded metrics.

        :param state_dict: state as produced by :meth:`state_dict`, must contain the key ``"metrics"``
        :return: None
        """
        self.metrics = state_dict["metrics"]

    def on_train_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        """Record the train phase metrics of the current epoch."""
        self.copy_metrics(trainer=trainer, phase=LearningPhase.TRAIN)

    def on_validation_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        """Record the validation phase metrics of the current epoch."""
        self.copy_metrics(trainer=trainer, phase=LearningPhase.VAL)

    # we need to gather test metrics only at the end (after module.test_epoch_end) to wait for bootstrapped computation
    def on_test_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        """Record the test phase metrics once testing has finished."""
        self.copy_metrics(trainer=trainer, phase=LearningPhase.TEST)

    def copy_metrics(self, trainer: lightning.Trainer, phase: LearningPhase) -> None:
        """
        Store the metrics of one phase in the record of the current epoch.

        Only entries of ``trainer.callback_metrics`` whose name starts with ``str(phase.value)`` are taken into
        account, each one is converted to a scalar via ``.item()``. Depending on the number of epochs recorded so
        far the metrics either open a new epoch record, extend the existing record of the current epoch or - in
        case of a mismatch - are appended after logging an error (missing epochs are padded with empty records).

        :param trainer: the trainer providing ``callback_metrics`` and ``current_epoch``
        :param phase: the learning phase whose metrics shall be recorded
        :return: None
        """
        prefix = str(phase.value)
        # ``.item()`` extracts plain python numbers, hence the tensors themselves never have to be copied
        phase_metrics = {
            name: metric_tensor.item()
            for name, metric_tensor in trainer.callback_metrics.items()
            if name.startswith(prefix)
        }
        logger.debug(f"{phase=}, logged metrics {phase_metrics.keys()}")
        epoch = trainer.current_epoch
        if len(self.metrics) == epoch:
            # first time copying for this epoch
            self.metrics.append(phase_metrics)
        elif len(self.metrics) == epoch + 1:
            # updating this epoch, for example adding train to val metrics
            self.metrics[epoch].update(phase_metrics)
        else:
            # we might have missed some epochs?
            diff = epoch - len(self.metrics)
            logger.error(f"There is a discrepancy of {diff} between metrics recorded and the current epoch!")
            # pad the missing epochs (no-op for a non-positive diff) so that the new record is placed last
            for _ in range(diff):
                self.metrics.append({})
            self.metrics.append(phase_metrics)
[docs]
class MMLRichProgressBar(RichProgressBar):
    """
    Lightning rich progress bar that displays an experiment identifier instead of the version number.
    """

    def get_metrics(self, trainer, model):
        """
        Assemble the items displayed in the progress bar.

        The ``v_num`` entry of the parent implementation is dropped (if present) and an ``exp`` entry is added,
        made of the last two components of the current working directory joined by ``/``.

        :param trainer: forwarded to the parent implementation
        :param model: forwarded to the parent implementation
        :return: the modified items of the parent implementation
        """
        shown = super().get_metrics(trainer, model)
        # don't show the version number
        shown.pop("v_num", None)
        cwd_parts = Path(os.getcwd()).parts
        shown["exp"] = "/".join(cwd_parts[-2:])
        return shown
[docs]
class MMLTQDMProgressBar(TQDMProgressBar):
    """
    Lightning tqdm progress bar that displays an experiment identifier instead of the version number.
    """

    def __init__(self, refresh_rate=1):
        """
        Set up the progress bar and determine the experiment identifier once.

        :param refresh_rate: forwarded to the parent progress bar
        """
        super().__init__(refresh_rate=refresh_rate)
        # last two components of the working directory at construction time
        working_dir = Path(os.getcwd())
        self.experiment_name = "/".join(working_dir.parts[-2:])

    def get_metrics(self, trainer, model):
        """
        Assemble the items displayed in the progress bar.

        The ``v_num`` entry of the parent implementation is dropped (if present) and the stored experiment
        identifier is added as ``exp``.

        :param trainer: forwarded to the parent implementation
        :param model: forwarded to the parent implementation
        :return: the modified items of the parent implementation
        """
        displayed = super().get_metrics(trainer, model)
        # don't show the version number
        displayed.pop("v_num", None)
        displayed["exp"] = self.experiment_name
        return displayed
[docs]
class MMLModelCheckpoint(ModelCheckpoint):
    """
    Slight modification of the lightning ModelCheckpoint (see
    https://github.com/Lightning-AI/pytorch-lightning/issues/20245).

    The "last" checkpoint is only refreshed when a checkpoint has been stored at the current global step, and the
    epoch end / validation end hooks are re-implemented on top of the private helpers of the parent class.
    NOTE(review): relies on private ``ModelCheckpoint`` members (``_last_global_step_saved``, ``_every_n_epochs``,
    ``_monitor_candidates``, ...) - confirm compatibility whenever the lightning version changes.
    """

    def _save_last_checkpoint(self, trainer: "lightning.Trainer", monitor_candidates: Dict[str, torch.Tensor]) -> None:
        """
        Only update last checkpoint in case there has just been a new checkpoint.

        :param trainer: the trainer, its ``global_step`` is compared with ``_last_global_step_saved``
        :param monitor_candidates: forwarded unchanged to the parent implementation
        :return: None
        """
        # presumably ``_last_global_step_saved`` is set by the parent class when it writes a checkpoint - verify
        if self._last_global_step_saved == trainer.global_step:
            super()._save_last_checkpoint(trainer=trainer, monitor_candidates=monitor_candidates)

    def on_train_epoch_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        """
        Save a checkpoint at the end of the training epoch.

        :param trainer: the trainer, passed on to the checkpointing helpers
        :param pl_module: the lightning module (unused)
        :return: None
        """
        # only act if saving is not skipped and checkpointing is configured to happen at train epoch end
        if not self._should_skip_saving_checkpoint(trainer) and self._should_save_on_train_epoch_end(trainer):
            monitor_candidates = self._monitor_candidates(trainer)
            # current_epoch is zero-based, hence the + 1 to hit every n-th epoch
            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                self._save_topk_checkpoint(trainer, monitor_candidates)
                # NOTE(review): nesting of this call below the every-n-epochs check mirrors on_validation_end
                # - confirm
                self._save_last_checkpoint(trainer, monitor_candidates)

    def on_validation_end(self, trainer: "lightning.Trainer", pl_module: "lightning.LightningModule") -> None:
        """
        Save a checkpoint at the end of the validation stage.

        :param trainer: the trainer, passed on to the checkpointing helpers
        :param pl_module: the lightning module (unused)
        :return: None
        """
        # counterpart of on_train_epoch_end: only act if checkpointing is NOT configured for train epoch end
        if not self._should_skip_saving_checkpoint(trainer) and not self._should_save_on_train_epoch_end(trainer):
            # current_epoch is zero-based, hence the + 1 to hit every n-th epoch
            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                monitor_candidates = self._monitor_candidates(trainer)
                self._save_topk_checkpoint(trainer, monitor_candidates)
                # monitor_candidates only exists within this branch, so the last checkpoint is handled here too
                self._save_last_checkpoint(trainer, monitor_candidates)