# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
import logging
import random
import warnings
from typing import List, Optional
import lightning
import numpy as np
import torch
from omegaconf import DictConfig, ListConfig
from mml.core.data_loading.task_attributes import DataSplit
from mml.core.scripts.exceptions import MMLMisconfigurationException
from mml.core.scripts.model_storage import ModelStorage
from mml.core.scripts.pipeline_configuration import PipelineCfg
from mml.core.scripts.schedulers.base_scheduler import AbstractBaseScheduler
from mml.core.scripts.utils import ARG_SEP, TAG_SEP, LearningPhase, catch_time
logger = logging.getLogger(__name__)
[docs]
class TrainingScheduler(AbstractBaseScheduler):
    """
    New version of the former "optimization" scheduler. It is built around a single pivot task and offers three
    subroutines:

    - model training (``train``)
    - model prediction (``predict``)
    - model testing (``test``)

    In addition to the standard hooks (after_preparation_hook, before_finishing_hook) it provides additional hooks
    that may be overridden by inheriting schedulers:

    - before_training_hook
    - after_training_hook

    It further allows for task nesting (``mode.nested``), cross validation (``mode.cv``), joint training with
    co-tasks (``mode.multitask`` / ``mode.co_tasks``) and evaluation on other tasks (``mode.eval_on``).
    """
[docs]
def __init__(self, cfg: DictConfig):
    """
    Validate the configuration and set up the training scheduler.

    Registers the subroutines ``train``, ``predict`` and ``test`` with the base scheduler, resolves the co-tasks
    used for multitask learning and runs a series of sanity checks on the configuration.

    :param DictConfig cfg: the experiment configuration
    :raises MMLMisconfigurationException: if no pivot task is set, the multitask / co-task options are
        inconsistent, cross validation without nesting is combined with a ``ReduceLROnPlateau`` scheduler or a
        val-monitoring ``EarlyStopping`` callback, parameter storing conflicts with the requested subroutines or
        trainer settings, ``mode.eval_on`` is malformed, or model reuse is combined with training plus
        testing/predicting
    """
    # check compliance to new train scheduler behaviour
    if not cfg.pivot.name:
        raise MMLMisconfigurationException(
            "Train mode (and all inherited ones) requires a pivot task to be set from mml version 0.12.0 onwards."
        )
    if ("nested" in cfg.pivot.name or "nested" in cfg.pivot.tags) and cfg.mode.nested:
        warnings.warn(
            "TrainingScheduler takes care of task nesting itself. Currently you are introducing DOUBLE "
            "nesting by setting mode.nested=true and choosing an already nested pivot task."
        )
    # initialize
    # NOTE(review): assigned before the base constructor runs - presumably because the base constructor
    # triggers create_routine, which overwrites this placeholder; confirm against AbstractBaseScheduler
    self.n_folds: int = -1  # will be set during create_routine
    super(TrainingScheduler, self).__init__(cfg=cfg, available_subroutines=["train", "predict", "test"])
    self.monitored_performances: List[float] = []  # final loss values observed during validation
    # interpretation and checking for multitask options
    self.co_tasks: List[str] = []
    if self.cfg.mode.multitask:
        if self.cfg.mode.multitask == 1:
            raise MMLMisconfigurationException(
                "To enable multitask learning set mode.multitask to the TOTAL number of task to be learned jointly."
            )
        if self.cfg.mode.co_tasks == "random":
            # any task that is neither the pivot nor derived from it is a candidate
            choices = [task for task in self.cfg.task_list if not task.startswith(self.pivot)]
            if len(choices) < self.cfg.mode.multitask - 1:
                raise MMLMisconfigurationException("Available tasks are not enough to support sufficient co-tasks.")
            self.co_tasks = random.sample(population=choices, k=self.cfg.mode.multitask - 1)
            logger.info(f"Randomly selected co-learning tasks: {self.co_tasks}")
        else:
            # was given explicit co tasks, check if compatible
            if any(task not in self.cfg.task_list for task in self.cfg.mode.co_tasks):
                raise MMLMisconfigurationException("Co task not present in cfg.task_list!")
            if len(set(self.cfg.mode.co_tasks)) != len(self.cfg.mode.co_tasks):
                raise MMLMisconfigurationException(
                    "mode.co_tasks supports identical co-tasks only if "
                    "you use the identity tag (task_name+identity) to "
                    "create a virtual duplicate."
                )
            if any(task.startswith(self.pivot) for task in self.cfg.mode.co_tasks):
                raise MMLMisconfigurationException("Should not use pivot (or derivative) as co-task.")
            self.co_tasks = self.cfg.mode.co_tasks
            logger.info(f"Configured co-learning tasks: {self.co_tasks}")
    # these are some information leakage backup checks
    if not self.cfg.mode.nested and self.cfg.mode.cv:
        # check lr scheduler leakage
        if (
            self.cfg.lr_scheduler["_target_"] is not None
            and "ReduceLROnPlateau" in self.cfg.lr_scheduler["_target_"]
        ):
            raise MMLMisconfigurationException(
                "Using ReduceLROnPlateau LR-scheduler without activating mode.nested=true leads to "
                "information leakage from val split on training and should be avoided. Either activate"
                " nesting or change the LR-scheduler, e.g. with <lr_scheduler=none>."
            )
        # check early stopping leakage
        for cb in self.cfg.cbs.values():
            if "EarlyStopping" in cb["_target_"] and "val" in cb["monitor"]:
                raise MMLMisconfigurationException(
                    "Using EarlyStopping callback without activating mode.nested=true leads to "
                    "information leakage from val split on training and should be avoided. Either deactivate this "
                    "callback (e.g. <callbacks=none>), choose a different monitor value (not depending on the val "
                    "split) or activate nesting."
                )
    # checks regarding the storing of model parameters
    if not self.cfg.mode.store_parameters and "train" in self.subroutines and "predict" in self.subroutines:
        raise MMLMisconfigurationException("Predictions after training require mode.store_parameters=True!")
    if (
        self.cfg.mode.store_parameters
        and "check_val_every_n_epoch" in self.cfg.trainer
        and self.cfg.trainer.check_val_every_n_epoch > self.cfg.trainer.max_epochs
    ):
        raise MMLMisconfigurationException(
            f"It seems like you only validate every {self.cfg.trainer.check_val_every_n_epoch} epochs, but "
            f'only train for max {self.cfg.trainer.max_epochs} although requested "mode.store_parameters".'
        )
    if self.cfg.mode.store_parameters and self.cfg.mode.cv:
        warnings.warn(
            f"Cross-Validation will store {self.n_folds} model parameters. To reduce memory consumption "
            "you may consider either setting mode.store_parameters=false (which will omit storing the "
            "model parameters) or reuse.clean_up.parameters=true (which deletes the model parameters "
            "at the end of the experiment)."
        )
    # more checks
    if self.cfg.mode.cv and "test" in self.subroutines:
        warnings.warn(
            "Chose both cross validation and testing (on hold out test set). Note that only one CV model "
            "will be evaluated!"
        )
    if not self.cfg.mode.nested and "test" in self.subroutines:
        warnings.warn(
            "You are testing on the `actual` test set! To ensure unbiased fair evaluation this should only "
            "be done on the very end of model development. "
            "You may chose mode.nested=true so the testing subroutine will be performed NOT on the (potential) "
            "official task test split, but on the hold-out fold."
        )
    if self.cfg.mode.eval_on:
        if not isinstance(self.cfg.mode.eval_on, (list, ListConfig)):
            raise MMLMisconfigurationException(
                f"Must provide mode.eval_on as list of tasks, gave {type(self.cfg.mode.eval_on).__name__}."
            )
        if any(t not in self.cfg.task_list for t in self.cfg.mode.eval_on):
            raise MMLMisconfigurationException(
                f"Chose to evaluate on {self.cfg.mode.eval_on} but one of these tasks is not given in the task_list"
            )
    if (
        self.cfg.reuse.models
        and "train" in self.subroutines
        and any(sub in self.subroutines for sub in ["test", "predict"])
    ):
        raise MMLMisconfigurationException(
            "Reusing existing models combined with training. This may lead to undetermined behaviour during "
            "testing/predicting"
        )
[docs]
def create_routine(self):
    """
    Fill ``self.commands`` / ``self.params`` with the calls for the train, predict and test sub-routines.

    The number of pivot folds is read from the pivot's task description. With ``mode.cv`` every fold is
    scheduled, otherwise only fold 0. With ``mode.nested`` the nested variants of the pivot task are appended
    to ``cfg.task_list`` and used (always with fold 0) instead of the pivot itself.

    :raises RuntimeError: if the pivot task is not present in the task index
    :return: None
    """
    # the pivot's task description tells how many folds are available
    try:
        description_path = self.fm.data_path / self.fm.task_index[self.pivot]["none"]
        pivot_description = self.fm.load_task_description(description_path)
    except KeyError:
        raise RuntimeError(
            f"Task {self.pivot} not found in task index. You may need to call info or create mode before passing "
            f"as pivot task to train."
        )
    self.n_folds = len(pivot_description.train_folds)
    use_nesting = self.cfg.mode.nested
    # cross validation loops over every fold, otherwise only the first one is used
    fold_ids = list(range(self.n_folds)) if self.cfg.mode.cv else [0]

    def nested_task(fold_id):
        # name of the nested variant of the pivot task for the given fold
        return f"{self.pivot}{TAG_SEP}nested{ARG_SEP}{fold_id}"

    # adapt task list, depending on the nesting and cv mode configurations, happens before prepare_exp struct
    # construction, but after the modifications of tasks and pivot by the auto tagging tasks options
    if use_nesting:
        for fold_id in fold_ids:
            self.cfg.task_list.append(nested_task(fold_id))
    # a single None entry means "predict / test on the trained task itself"
    eval_targets = self.cfg.mode.eval_on or [None]
    # -- training commands
    if "train" in self.subroutines:
        for fold_id in fold_ids:
            self.commands.append(self.train_fold)
            self.params.append([nested_task(fold_id), 0] if use_nesting else [self.pivot, fold_id])
    # -- prediction commands (test split, which is the original val split for nested tasks)
    if "predict" in self.subroutines:
        for fold_id in fold_ids:
            for target in eval_targets:
                self.commands.append(self.predict_fold)
                if not use_nesting:
                    self.params.append([self.pivot, fold_id, target])
                    continue
                self.params.append([nested_task(fold_id), 0, target])
                if target is None:
                    # additionally predict on the original (not nested) pivot - to be used in postprocessing
                    self.commands.append(self.predict_fold)
                    self.params.append([nested_task(fold_id), 0, self.pivot])
    # -- testing commands (only scheduled once, independent of the folds)
    if "test" in self.subroutines:
        for target in eval_targets:
            self.commands.append(self.test_task)
            self.params.append([nested_task(0) if use_nesting else self.pivot, target])
[docs]
def after_preparation_hook(self):
    """
    Verify that every task listed in ``mode.eval_on`` is compatible with the pivot task.

    Compatible means identical ``task_type`` and identical ``num_classes`` of the task structs. Does nothing
    if ``mode.eval_on`` is empty.

    :raises MMLMisconfigurationException: if task type or number of classes differ
    :return: None
    """
    eval_tasks = self.cfg.mode.eval_on
    if not eval_tasks:
        # nothing to compare against
        return
    for eval_task in eval_tasks:
        pivot_struct = self.get_struct(self.pivot)
        eval_struct = self.get_struct(eval_task)
        # first requirement: same kind of task
        if pivot_struct.task_type != eval_struct.task_type:
            raise MMLMisconfigurationException(
                f"Invalid task type for evaluation! Pivot task has type "
                f"{pivot_struct.task_type} but evaluation task {eval_task} has type "
                f"{eval_struct.task_type}."
            )
        # second requirement: same number of classes
        if pivot_struct.num_classes != eval_struct.num_classes:
            raise MMLMisconfigurationException(
                f"Invalid number of classes for evaluation! Pivot task has "
                f"{pivot_struct.num_classes} classes but evaluation {eval_task} task has "
                f"{eval_struct.num_classes} classes."
            )
[docs]
def before_finishing_hook(self):
    """
    Set the scheduler's return value and, for train-only runs, log aggregated validation metrics.

    The return value is the mean of ``self.monitored_performances`` (one entry per trained fold). If training
    calls were planned and neither ``test`` nor ``predict`` is part of the subroutines, the last recorded
    validation metrics of each trained model are collected and logged as mean ± standard deviation.

    :return: None
    """
    # return the task loss (averaged over folds if mode.cv is active)
    self.return_value = np.mean(self.monitored_performances)
    # look up which training calls were planned
    with open(self.planned_schedule, "r") as handle:
        schedule_lines = handle.readlines()
    training_lines = [entry for entry in schedule_lines if self.train_fold.__name__ in entry]
    # validation results are only shown if neither predict nor test are applied
    if len(training_lines) == 0 or "test" in self.subroutines or "predict" in self.subroutines:
        return
    # recover the (task_name, fold) pair of every planned training call
    # NOTE(review): assumes the call arguments are the bracketed list after the last "/" of a line - confirm
    # against the writer of the planned schedule file
    call_args = [
        tuple(token.strip(" '") for token in entry.split("/")[-1].strip(" []\n").split(","))
        for entry in training_lines
    ]
    logger.info(f"Will try to aggregate validation results over {len(call_args)} training runs.")
    collected = {}
    for task_name, fold in call_args:
        struct = self.get_struct(task_name)
        matching = [model for model in struct.models if model.fold == int(fold)]
        if len(matching) != 1:
            # this can happen if reuse was used beforehand
            logger.error(
                f"Ambiguous model choices for {task_name} and {fold}! Will skip while aggregating results."
            )
            continue
        (model,) = matching
        # positions of all recorded metric dicts that contain at least one validation entry
        val_positions = [
            position
            for position, metric_dict in enumerate(model.metrics)
            if any(k.startswith(LearningPhase.VAL) for k in metric_dict)
        ]
        if len(val_positions) == 0:
            logger.error(
                f"No validation metrics for task {task_name} and fold {fold}! Will skip while "
                f"aggregating results."
            )
            continue
        # only the most recent validation record contributes
        for metric_name, metric_value in model.metrics[val_positions[-1]].items():
            collected.setdefault(metric_name, []).append(metric_value)
    # compute stats and show
    if len(collected) == 0:
        logger.error("No validation metrics found!")
        return
    logger.info("Aggregated validation results over training:")
    for metric, values in collected.items():
        logger.info(f"{metric} : {np.mean(values):.2f} ± {np.std(values):.2f}")
[docs]
def before_training_hook(
    self,
    datamodule: lightning.LightningDataModule,
    model: lightning.LightningModule,
    trainer: lightning.Trainer,
    fold: int,
    task_name: str,
) -> None:
    """
    Extension point invoked right before model fitting (and before lightning tuning) starts.

    The default implementation does nothing. Inheriting schedulers may override it to adapt task weights,
    data, trainer callbacks, etc.

    :param lightning.LightningDataModule datamodule: the prepared datamodule (no setup run yet)
    :param lightning.LightningModule model: the prepared model
    :param lightning.Trainer trainer: the prepared trainer
    :param int fold: the current fold
    :param str task_name: the current task
    :return: None
    """
    return None
[docs]
def after_training_hook(
    self,
    datamodule: lightning.LightningDataModule,
    model: lightning.LightningModule,
    trainer: lightning.Trainer,
    fold: int,
    task_name: str,
) -> None:
    """
    Extension point invoked right after model fitting (and potential lightning tuning) ended.

    The default implementation does nothing. Inheriting schedulers may override it to adapt task weights,
    data, trainer callbacks, etc.

    :param lightning.LightningDataModule datamodule: the datamodule used
    :param lightning.LightningModule model: the trained model
    :param lightning.Trainer trainer: the used trainer
    :param int fold: the used fold
    :param str task_name: the pivot task
    :return: None
    """
    return None
[docs]
def train_fold(self, task_name: str, fold: int) -> None:
    """
    Train a model on one fold of the given task (jointly with the configured co-tasks) and store the result.

    The pipeline configuration is either loaded from the task's blueprint (``mode.use_blueprint``) or derived
    from the current config. After fitting, the monitored validation loss is appended to
    ``self.monitored_performances`` and a :class:`ModelStorage` is created, stored and attached to the task
    struct. Model parameters are only written if ``mode.store_parameters`` is set.

    :param str task_name: name of the task to train on
    :param int fold: the fold to be used
    :raises RuntimeError: if a blueprint is requested but not available for the task, or if no validation
        loss of the task can be found in the recorded metrics
    :return: None
    """
    logger.info("Starting training for task " + self.highlight_text(task_name) + f" and fold {fold}.")
    pivot_struct = self.get_struct(task_name)
    co_structs = [self.get_struct(task_name=co_task) for co_task in self.co_tasks]
    if self.cfg.mode.use_blueprint:
        if "blueprint" in pivot_struct.paths:
            pipeline = PipelineCfg.load(
                path=pivot_struct.paths["blueprint"], pipeline_keys=self.cfg.mode.pipeline_keys
            )
            logger.info(f"Found blueprint pipeline for task {task_name}, will evaluate that.")
        else:
            raise RuntimeError(f"Was not able to find appropriate blueprint pipeline for task {task_name}!")
    else:
        pipeline = PipelineCfg.from_cfg(current_cfg=self.cfg, restrict_keys=self.cfg.mode.pipeline_keys)
    with pipeline.activate(current_cfg=self.cfg):
        # preparation
        datamodule = self.create_datamodule(task_structs=[pivot_struct] + co_structs, fold=fold)
        module = self.create_model(
            task_structs=[pivot_struct] + co_structs, task_weights=self.cfg.mode.task_weights
        )
        module.train()  # see https://github.com/Lightning-AI/pytorch-lightning/releases/tag/2.2.0
        trainer = self.create_trainer(
            monitor=(f"val/{task_name}/loss", "min") if self.cfg.mode.store_best else None, metrics_callback=True
        )
        self.before_training_hook(
            datamodule=datamodule, model=module, trainer=trainer, fold=fold, task_name=task_name
        )
        # tuning and training
        with catch_time() as training_timer:
            self.lightning_tune(trainer=trainer, model=module, datamodule=datamodule)
            trainer.fit(model=module, datamodule=datamodule)
        self.after_training_hook(
            datamodule=datamodule, model=module, trainer=trainer, fold=fold, task_name=task_name
        )
        # create another pipeline from the current one (within blueprint keys activated and without restrictions)
        # to ensure storing the full superset of configuration from a potential partially masked blueprint training
        pipeline_path = PipelineCfg.from_cfg(current_cfg=self.cfg).store(
            task_struct=pivot_struct, as_blueprint=False
        )
    # output processing
    # NOTE(review): self.checkpoint_callback / self.metrics_callback are presumably attached by create_trainer
    if self.cfg.mode.store_best:
        if self.checkpoint_callback.best_model_score is None:
            best_score = 1000  # catch fast_dev_run
        else:
            best_score = self.checkpoint_callback.best_model_score.item()
    else:
        try:
            # NOTE(review): the second to last metrics entry is read - presumably the last validation record
            best_score = self.metrics_callback.metrics[-2][f"val/{task_name}/loss"]
        except KeyError as err:
            raise RuntimeError(
                f'Unable to find "val/{task_name}/loss" in recorded metrics of the last epoch, '
                "make sure to activate validation with lightning trainer."
            ) from err
    self.monitored_performances.append(best_score)
    parameters_path = self.fm.construct_saving_path(module, key="parameters", task_name=pivot_struct.name)
    if self.cfg.mode.store_parameters:
        # determine the correct parameters directory
        cpt_path = (
            self.checkpoint_callback.best_model_path
            if self.cfg.mode.store_best
            else self.checkpoint_callback.last_model_path
        )
        # load these weights (checkpoint was written by this very run, hence full unpickling is accepted)
        state_dict = torch.load(cpt_path, map_location=torch.device("cpu"), weights_only=False)["state_dict"]
        module.load_state_dict(state_dict)
        # store model
        module.model.save_checkpoint(param_path=parameters_path)
    else:
        logger.info("mode.store_parameters is set false, so no parameters will be stored!")
    storage = ModelStorage(
        pipeline=pipeline_path,
        parameters=parameters_path,
        fold=fold,
        task=task_name,
        performance=best_score,
        metrics=self.metrics_callback.metrics,
        training_time=training_timer.elapsed,
    )
    storage.store(task_struct=pivot_struct, fold=fold)
    pivot_struct.models.append(storage)
    logger.info("Finished training for task " + self.highlight_text(task_name) + f" and fold {fold}.")
[docs]
def predict_fold(self, task_name: str, fold: int, eval_on: Optional[str] = None) -> None:
    """
    Predict with the latest stored model of a task and fold and save the predictions.

    Predictions are made on the TEST, VAL and UNLABELLED split, un-batched into one dict per sample and saved
    as a single file. The path of that file is registered in the model storage under the evaluated task.

    :param str task_name: name of the task the model was trained on
    :param int fold: the fold of the model to be used
    :param Optional[str] eval_on: (optional) other task to predict on, falsy values or the task itself result
        in predictions on ``task_name``
    :raises RuntimeError: if no model storage exists for the task and fold
    :return: None
    """
    logger.info("Starting predicting for task " + self.highlight_text(task_name) + f" and fold {fold}.")
    task_struct = self.get_struct(task_name)
    # find model storage, the most recently created one of the requested fold wins
    choices = [candidate for candidate in task_struct.models if candidate.fold == fold]
    if len(choices) == 0:
        raise RuntimeError(f"Did not find any existing model storage for task {task_name} and fold {fold}.")
    storage = sorted(choices, key=lambda candidate: candidate.created)[-1]
    logger.info(f"Found {len(choices)} matching model storages, used the latest from {storage.created}.")
    # check preprocessing compatibility
    original_pipeline = PipelineCfg.load(path=storage.pipeline)
    stored_preprocessing = original_pipeline.pipeline_cfg.preprocessing.id
    if stored_preprocessing != self.cfg.preprocessing.id:
        warnings.warn(
            f"Current preprocessing is {self.cfg.preprocessing.id} but the loaded model was originally "
            f"preprocessed as {original_pipeline.pipeline_cfg.preprocessing.id}. MML will try to continue "
            f"with the given preprocessing pipeline."
        )
    # prepare model
    module = self.create_model(task_structs=[task_struct], load_parameters=storage.parameters)
    if not eval_on or eval_on == task_name:
        # predict on the task the model was trained on
        eval_task_name, eval_task = task_name, task_struct
    else:
        # predict on a different task by redirecting the trained head
        eval_task = self.get_struct(eval_on)
        eval_task_name = eval_on
        logger.info("Will predict on task " + self.highlight_text(eval_task_name) + "!")
        module.setup_redirection(head=task_name, task=eval_task_name)
    # prepare data and trainer
    datamodule = self.create_datamodule(task_structs=eval_task, fold=fold)
    trainer = self.create_trainer()
    # perform predictions, one split after the other
    batched = {}
    with catch_time() as predict_timer:
        for split in (DataSplit.TEST, DataSplit.VAL, DataSplit.UNLABELLED):
            logger.info(f"Predicting split {split.name}.")
            # switch prediction split
            datamodule.predict_on = split
            outputs = trainer.predict(model=module, dataloaders=datamodule, return_predictions=True)
            batched[split] = outputs
            if outputs is not None:
                logger.info(f"Predicted {len(outputs)} batches for split {split.name}.")
    logger.debug(f"Prediction time was {predict_timer.pretty_time}.")
    # reformat predictions as one dict per sample for each data split and combine them
    split_unbatched_predictions = {}
    for data_split, batches in batched.items():
        samples = []
        split_unbatched_predictions[data_split.name] = samples
        if batches is None:
            warnings.warn(f"No predictions found for {data_split}!")
            continue
        for batch in batches:
            head_output = batch[eval_task_name]
            for sample_idx in range(head_output["logits"].size(0)):
                entry = {"logits": head_output["logits"][sample_idx]}
                # targets and sample ids are optional
                if head_output["targets"] is not None:
                    entry["target"] = head_output["targets"][sample_idx]
                if head_output["sample_ids"] is not None:
                    entry["sample_id"] = head_output["sample_ids"][sample_idx]
                samples.append(entry)
    preds_path = self.fm.construct_saving_path(
        split_unbatched_predictions, key="predictions", task_name=eval_task.name, file_name=f"preds-fold-{fold}.pt"
    )
    torch.save(split_unbatched_predictions, preds_path)
    # register the predictions within the model storage
    storage.predictions[eval_task_name] = preds_path
    storage.store()
    logger.info("Finished predicting for task " + self.highlight_text(task_name) + f" and fold {fold}.")
[docs]
def test_task(self, task_name: str, eval_on: Optional[str] = None) -> None:
    """
    Test the latest stored model of a task and append the test metrics to its model storage.

    The most recently created model storage of the task is used, independent of its fold.

    :param str task_name: name of the task the model was trained on
    :param Optional[str] eval_on: (optional) other task to test on, falsy values or the task itself result in
        testing on ``task_name``
    :raises RuntimeError: if no model storage exists for the task
    :return: None
    """
    logger.info("Starting testing for task " + self.highlight_text(task_name))
    task_struct = self.get_struct(task_name)
    # find model storage
    if len(task_struct.models) == 0:
        raise RuntimeError(f"Did not find any existing model storage for task {task_name}.")
    # sort ascending by creation, so the latest storage is the last one
    choices = sorted(task_struct.models, key=lambda x: x.created)
    storage = choices[-1]
    logger.info(f"Found {len(choices)} matching model storages, used the latest from {storage.created}.")
    # check preprocessing compatibility
    original_pipeline = PipelineCfg.load(path=storage.pipeline)
    if original_pipeline.pipeline_cfg.preprocessing.id != self.cfg.preprocessing.id:
        warnings.warn(
            f"Current preprocessing is {self.cfg.preprocessing.id} but the loaded model was originally "
            f"preprocessed as {original_pipeline.pipeline_cfg.preprocessing.id}. MML will try to continue "
            f"with the given preprocessing pipeline."
        )
    # prepare model
    module = self.create_model(task_structs=[task_struct], load_parameters=storage.parameters)
    if eval_on and eval_on != task_name:
        # test on a different task by redirecting the trained head
        eval_task = self.get_struct(eval_on)
        logger.info("Will test on task " + self.highlight_text(eval_on) + "!")
        module.setup_redirection(head=task_name, task=eval_on)
    else:
        eval_task = task_struct
    # prepare data and trainer
    datamodule = self.create_datamodule(task_structs=eval_task)
    trainer = self.create_trainer(metrics_callback=True)
    # run the testing
    with catch_time() as test_timer:
        trainer.test(model=module, datamodule=datamodule)
    logger.debug(f"Testing time was {test_timer.pretty_time}.")
    # NOTE(review): self.metrics_callback is presumably attached by create_trainer(metrics_callback=True)
    storage.metrics += self.metrics_callback.metrics
    storage.store()
    logger.info(f"Results: {self.metrics_callback.metrics}")
    logger.info("Finished testing for task " + self.highlight_text(task_name))