Source code for mml.core.scripts.schedulers.info_scheduler

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

import logging
from pathlib import Path

import numpy as np
import optuna
import torch
from omegaconf import DictConfig
from PIL import Image
from prettytable.colortable import ColorTable, Themes
from torchvision.utils import make_grid
from tqdm import tqdm

from mml.core.data_loading.task_attributes import DataSplit, Modality, TaskType
from mml.core.data_loading.task_dataset import TaskDataset
from mml.core.scripts.exceptions import MMLMisconfigurationException
from mml.core.scripts.schedulers.base_scheduler import AbstractBaseScheduler
from mml.core.scripts.utils import LearningPhase, throttle_logging

logger = logging.getLogger(__name__)


class InfoScheduler(AbstractBaseScheduler):
    """
    AbstractBaseScheduler implementation for receiving information on a project.

    Includes the following subroutines:

    - tasks (show information on tasks)
    - hpo (show information on hpo results)
    - samples (plots sample images for each task (and/or jointly))
    - models (show information on existing model descriptions)
    """

    def __init__(self, cfg: DictConfig):
        """
        Set up the scheduler with its four available subroutines.

        :param cfg: the experiment configuration
        """
        super().__init__(cfg=cfg, available_subroutines=["tasks", "hpo", "samples", "models"])
        # accumulated by info_task, reported in before_finishing_hook
        self.total_sample_sum = 0

    def create_routine(self):
        """
        Fill ``self.commands`` / ``self.params`` according to the selected subroutines.

        Four subroutines are supported: ``tasks`` (one command per entry of ``cfg.task_list``), ``hpo`` (one command
        per study), ``samples`` (a single plotting command) and ``models`` (a single table command).

        :return: None
        """
        # -- add task info commands
        if "tasks" in self.subroutines:
            for task in self.cfg.task_list:
                self.commands.append(self.info_task)
                self.params.append([task])
        # -- add study info commands
        if "hpo" in self.subroutines:
            for study_name in self._select_study_names():
                self.commands.append(self.info_study)
                self.params.append([study_name])
        # -- add sample plotting command
        if "samples" in self.subroutines:
            self.commands.append(self.plot_samples)
            self.params.append([])
        # -- add model info command
        if "models" in self.subroutines:
            self.commands.append(self.info_models)
            self.params.append([])

    def _select_study_names(self) -> list:
        """
        Determine the hpo studies to report on.

        :return: ``[cfg.mode.study_name]`` if that is set, otherwise the names of all studies in ``cfg.hpo.storage``
            that start with ``cfg.proj``
        """
        if self.cfg.mode.study_name:
            # was given a single study
            return [self.cfg.mode.study_name]
        # show all available studies for this project
        logger.info("Was given no study name to search for, so showing all studies with project prefix.")
        study_summaries = optuna.study.get_all_study_summaries(storage=self.cfg.hpo.storage)
        logger.debug(f"Found {len(study_summaries)} studies in database.")
        return [
            summary.study_name
            for summary in study_summaries
            if str(summary.study_name).startswith(self.cfg.proj)
        ]

    def after_preparation_hook(self):
        """Nothing to do after preparation."""
        pass

    def before_finishing_hook(self):
        """Log the number of samples accumulated over all ``info_task`` calls."""
        logger.info(f"Total number of all samples: {self.total_sample_sum}.")

    def prepare_exp(self) -> None:
        """
        Run the parent preparation only if a subroutine that works on tasks was selected.

        :return: None
        """
        # no need to prepare tasks if only looking at hpo results
        # NOTE: the plotting subroutine is registered as "samples" (see __init__), not "sample_grid"
        if any(name in self.subroutines for name in ("tasks", "samples", "models")):
            super().prepare_exp()

    def info_task(self, task_name: str) -> None:
        """
        Log the task struct, the size of the full train split and (classification only) validation class counts.

        :param task_name: name of the task to report on
        :return: None
        """
        logger.info("Starting info on task " + self.highlight_text(task_name))
        task_struct = self.get_struct(task_name)
        logger.info(str(task_struct))
        dataset = TaskDataset(
            root=Path(self.cfg.data_dir) / task_struct.relative_root, split=DataSplit.FULL_TRAIN, fold=0, transform=None
        )
        num_samples = len(dataset)
        logger.info(f"Num samples (full train set): {num_samples}")
        self.total_sample_sum += num_samples
        if dataset.task_type == TaskType.CLASSIFICATION:
            # switch the very same dataset over to the validation split of fold 0
            dataset.select_samples(split=DataSplit.VAL, fold=0)
            class_occ = {}
            for sample in dataset.samples:
                _cl = sample[Modality.CLASS]
                class_occ[_cl] = class_occ.get(_cl, 0) + 1
            logger.info(f"Default validation class occurrences are: {class_occ}")
        logger.info("Finished info on task " + self.highlight_text(task_name))

    def info_study(self, study_name: str) -> None:
        """
        Log the trials, summary statistics, top five trials and best parameters of an optuna study.

        Sorting by ``value`` and reading ``best_params`` / ``best_value`` are kept from the original code; they
        presumably rely on a single-objective study (TODO confirm).

        :param study_name: name of the study inside ``cfg.hpo.storage``
        :raises MMLMisconfigurationException: if loading the study raises a ``KeyError``
        :return: None
        """
        logger.info("Starting info on study " + self.highlight_text(study_name))
        try:
            loaded_study = optuna.load_study(study_name=study_name, storage=self.cfg.hpo.storage)
        except KeyError as err:
            # keep the original KeyError attached as the cause
            raise MMLMisconfigurationException(
                f"Study {study_name} not found! Adapt mode.study_name config, "
                f"and check your hpo.storage settings to make studies persistent. "
                f"See https://docs.sqlalchemy.org/en/20/core/engines.html for details. "
                f"Create the study from scratch with mml ... "
                f"hydra.sweeper.study_name=YOUR_STUDY_NAME "
                f"hpo.storage=YOUR_STORAGE_SETTING --multirun."
            ) from err
        df = loaded_study.trials_dataframe()
        if df.empty:
            # without trials there are no columns to drop or sort by
            logger.info("Study does not contain any trials yet.")
        else:
            # errors="ignore": not every column is guaranteed to be present in the trials frame
            detailed = df.drop(columns="number", errors="ignore").to_string()
            logger.info(f"Detailed study info:\n{detailed}")
            stats = df.drop(columns=["number", "duration"], errors="ignore").describe().T.to_string()
            logger.info(f"Statistics of study:\n{stats}")
            top = (
                df.drop(
                    columns=[
                        "number",
                        "duration",
                        "datetime_start",
                        "datetime_complete",
                        "state",
                        "system_attrs_search_space",
                    ],
                    errors="ignore",
                )
                .sort_values(by="value", ascending=self.cfg.hpo.direction == "minimize", ignore_index=True)
                .head(5)
                .T.to_string()
            )
            logger.info(f"Top entries:\n{top}")
        if len(loaded_study.get_trials(states=(optuna.trial.TrialState.COMPLETE,))) > 0:
            best_params = loaded_study.best_params
            best_val = loaded_study.best_value
            logger.info(f"Best params: {best_params}")
            logger.info(f"Best value: {best_val}")
        else:
            logger.info("No complete trial corresponding to this study has been found.")
        logger.info("Finished info on study " + self.highlight_text(study_name))

    @staticmethod
    def _save_tensor_as_image(tensor: torch.Tensor, path) -> None:
        """
        Store an image tensor on disk via PIL.

        :param tensor: image tensor; its first axis is moved to the last position before saving
        :param path: destination handed to ``PIL.Image.Image.save``
        :return: None
        """
        ndarr = tensor.permute(1, 2, 0).to("cpu", torch.uint8).numpy()
        Image.fromarray(ndarr).save(path)

    def plot_samples(self) -> None:
        """
        Load one training image per task and save it as a joint grid and/or as individual images.

        ``cfg.mode.info_grid`` toggles the joint grid, ``cfg.mode.info_individual`` the per-task images.

        :return: None
        """
        logger.info("Starting plotting sample grid of all tasks.")
        if len(self.cfg.task_list) == 0:
            logger.error("No tasks selected to show samples from.")
            return
        img_samples = {}
        # sort by sample size
        task_sizes = {struct.name: sum(struct.class_occ.values()) for struct in self.task_factory.container}
        ordered_names = sorted(task_sizes, key=task_sizes.get)
        for task_name in tqdm(ordered_names, desc="Loading samples"):
            task_struct = self.get_struct(task_name=task_name)
            datamodule = self.create_datamodule(task_structs=task_struct)
            # suppress warning of non-normalized image usage
            with throttle_logging(level=logging.WARNING, package="mml.core.data_loading.lightning_datamodule"):
                datamodule.setup(stage="fit")
            # last training sample, scaled by 255 before the later uint8 conversion
            img_samples[task_name] = datamodule.task_datasets[task_name][LearningPhase.TRAIN][-1][Modality.IMAGE] * 255
        if self.cfg.mode.info_grid:
            # gridplot
            proportion = 11 / 5  # width to height proportion
            grid = make_grid(tensor=list(img_samples.values()), nrow=int(np.sqrt(len(img_samples) * proportion)))
            path = self.fm.construct_saving_path(obj=grid, key="sample_grid")
            self._save_tensor_as_image(grid, path)
            logger.info(f"Finished plotting sample grid. Can be found at {path}.")
        if self.cfg.mode.info_individual:
            # individual samples plots
            for task_name, img_tensor in img_samples.items():
                path = self.fm.construct_saving_path(obj=img_tensor, key="img_examples", task_name=task_name)
                self._save_tensor_as_image(img_tensor, path)
            logger.info("Finished plotting individual samples for each task.")

    def info_models(self):
        """
        Print a table of the models attached to each task of ``cfg.task_list``.

        Tasks without any model are collected and reported in a final log message.

        :return: None
        """
        if len(self.cfg.task_list) == 0:
            logger.error("No tasks selected to show models from.")
            return
        logger.info(
            "Model info shows only loaded models (not all existing!). Use reuse.models=... to select a "
            "project to load models from."
        )
        table = ColorTable(theme=Themes.OCEAN)
        table.field_names = ["task", "created", "fold", "performance", "training (secs)", "params?", "preds?"]
        no_model_tasks = []
        for task in self.cfg.task_list:
            task_models = self.get_struct(task).models
            if len(task_models) == 0:
                no_model_tasks.append(task)
                continue
            for ix, model in enumerate(task_models):
                table.add_row(
                    [
                        model.task,
                        model.created,
                        model.fold,
                        # explicit None check: a performance of exactly 0.0 is a valid value
                        f"{model.performance:.8f}" if model.performance is not None else "x",
                        int(model.training_time),
                        "y" if model.parameters and model.parameters.exists() else "n",
                        len(model.predictions),
                    ],
                    # visually separate tasks by a divider after the last model of each task
                    divider=ix + 1 == len(task_models),
                )
        print(table)
        logger.info(f"No models found for {len(no_model_tasks)} tasks ({no_model_tasks}).")