Source code for mml.core.models.torch_base

# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#

from __future__ import annotations

import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import hydra.utils
import torch

from mml.core.data_loading.task_attributes import RGBInfo, TaskType
from mml.core.data_loading.task_struct import TaskStruct

# module level logger, used by all model classes below
logger = logging.getLogger(__name__)


class BaseModel(torch.nn.Module, ABC):
    """Abstract base of MML models: a single backbone plus a dictionary of named heads."""

    def __init__(self, **kwargs: Any) -> None:
        """
        The base class for MML models. Derived classes must implement the following methods:

         - :meth:`_init_model` - for backbone initialization
         - :meth:`_create_head` - for head creation
         - :meth:`supports` - reporting supported task types
         - :meth:`forward` - the models forward pass through backbone and all heads
         - :meth:`forward_features` - alternative usage as feature extractor

        :param Any kwargs: kept as init kwargs (they are written to checkpoints) and forwarded
            to :meth:`_init_model`
        """
        super().__init__()
        # model requirements
        self.required_mean: Optional[RGBInfo] = None  # mean expected by model
        self.required_std: Optional[RGBInfo] = None  # std expected by model
        self.input_size = (None, None, None)  # channel, height, width - will be defined during init
        # nn modules
        self.backbone: Union[torch.nn.Module, None] = None
        self.heads = torch.nn.ModuleDict({})
        # for freezing functionality, holds names of backbone parameters frozen by freeze_backbone
        self._frozen_params: List[str] = []
        # store init kwargs
        self._init_kwargs: Dict[str, Any] = kwargs  # stores stuff that needs to be persistent when re-initializing
        self._head_init_kwargs: List[Dict[str, Any]] = []  # stores init kwargs of heads (same order as self.heads)
        # actually init backbone
        self._init_model(**kwargs)
        logger.debug("Model initialised.")

    @abstractmethod
    def _init_model(self, **kwargs: Any) -> None:
        """
        This shall implement the backbone module as well as potentially load pretrained weights thereof.

        :param Any kwargs: the kwargs the model was constructed with
        """
        raise NotImplementedError

    @abstractmethod
    def _create_head(self, task_type: TaskType, num_classes: int, **kwargs: Any) -> BaseHead:
        """
        This shall implement the creation of heads. Given a certain task type the head must be able to be
        attached to the backbone as implemented by the forward method.

        :param TaskType task_type: task type the head is created for
        :param int num_classes: number of classes of the task
        :param Any kwargs: additional head kwargs
        :return: the newly created head
        """
        raise NotImplementedError

    @abstractmethod
    def supports(self, task_type: TaskType) -> bool:
        """
        Whether the model supports a given task type.

        :param TaskType task_type:
        :return: true iff task type is supported by model
        """
        pass

    @abstractmethod
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Model forward functionality. Passes input through the backbone once and forwards output through each head.

        :param torch.tensor x: input tensor
        :return: a dictionary, with one entry per head, key is head name and value is head output
        """
        pass

    @abstractmethod
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Feature extraction functionality. Only forwards features through the backbone and post-processes them to 1D.

        :param torch.tensor x: input tensor
        :return: a 1D tensor
        """
        pass

    def add_head(self, task_struct: TaskStruct, **kwargs: Any) -> None:
        """
        The functionality used to add heads to a model.

        :param TaskStruct task_struct: struct for the task to add a head for
        :param Any kwargs: additional kwargs that will be forwarded
        :raises KeyError: if a head with the name of the task is already present
        :raises RuntimeError: if the task type of the task is not supported by the model
        """
        if task_struct.name in self.heads:
            raise KeyError(f"You cannot register a head with already present name ({task_struct.name}).")
        if not self.supports(task_type=task_struct.task_type):
            raise RuntimeError(f"Task type {task_struct.task_type} not supported by model.")
        init_kwargs = {"task_type": task_struct.task_type, "num_classes": task_struct.num_classes}
        init_kwargs.update(kwargs)
        self.heads[task_struct.name] = self._create_head(**init_kwargs)
        # remembered so that the head can be re-created when loading a checkpoint
        self._head_init_kwargs.append(init_kwargs)
        logger.debug(
            f"Added head {task_struct.name} of task type {task_struct.task_type} with "
            f"{task_struct.num_classes} classes."
        )

    def count_parameters(self, only_trainable: bool = True) -> Dict[str, int]:
        """
        Gives information on parameter count of the model.

        :param bool only_trainable: if True, only counts parameters that requires_grad.
        :return: a dict with component names as key (backbone or name of heads) and parameter count as value
        """
        info_dict: Dict[str, int] = {}
        for name, module in [("backbone", self.backbone)] + list(self.heads.items()):
            if module is None:
                # no backbone has been set
                info_dict[name] = 0
            else:
                info_dict[name] = sum(p.numel() for p in module.parameters() if p.requires_grad or not only_trainable)
        return info_dict

    def freeze_backbone(self) -> None:
        """
        Freezes all backbone parameters. Only parameters that currently require gradients are touched and
        remembered, so that :meth:`unfreeze_backbone` exactly reverts this call.
        """
        for name, par in self.backbone.named_parameters():  # type: ignore[union-attr]
            if par.requires_grad:
                par.requires_grad = False
                self._frozen_params.append(name)
        logger.debug(f"Froze {len(self._frozen_params)} parameters of model.")

    def unfreeze_backbone(self) -> None:
        """
        Unfreezes previously frozen backbone parameters.
        """
        # set lookup avoids a linear scan of the list for every single backbone parameter
        frozen = set(self._frozen_params)
        for name, par in self.backbone.named_parameters():  # type: ignore[union-attr]
            if name in frozen:
                par.requires_grad = True
        logger.debug(f"Unfroze {len(self._frozen_params)} params of model.")
        self._frozen_params = []

    @staticmethod
    def load_checkpoint(param_path: Union[Path, str]) -> "BaseModel":
        """
        Load from a checkpoint as written by :meth:`save_checkpoint`.

        :param Union[Path, str] param_path: path to load checkpoint from
        :return: the re-instantiated model with backbone and head weights loaded
        """
        # SECURITY: weights_only=False unpickles arbitrary python objects (the checkpoint stores the model class
        # and init kwargs), which may execute code - only load checkpoints from trusted sources
        state = torch.load(param_path, weights_only=False)
        model: BaseModel = hydra.utils.instantiate(dict(_target_=state["__target__"], **state["__init_kwargs__"]))
        model.backbone.load_state_dict(state["backbone"])  # type: ignore[union-attr]
        # NOTE(review): this restores only the list of names, requires_grad of the parameters is not modified here
        model._frozen_params = state["__frozen_params__"]
        head_names = state["__head_names__"]
        head_init_kwargs = state["__head_init_kwargs__"]
        if len(head_names) != len(head_init_kwargs):
            # zip below stops at the shorter sequence, so make dropped heads visible instead of silent
            logger.warning(
                f"Checkpoint lists {len(head_names)} head(s) but {len(head_init_kwargs)} head init kwargs, "
                f"only heads with init kwargs can be restored."
            )
        for head_name, init_kwargs in zip(head_names, head_init_kwargs):
            head = model._create_head(**init_kwargs)
            model.heads[head_name] = head
            head.load_state_dict(state[head_name])
            # keep the bookkeeping of add_head, otherwise saving the loaded model again would write no
            # head init kwargs and its heads could not be restored from that checkpoint
            model._head_init_kwargs.append(init_kwargs)
        logger.info("Loaded checkpoint!")
        logger.debug(f"@ {param_path}")
        return model

    def save_checkpoint(self, param_path: Union[Path, str]) -> None:
        """
        Save a model checkpoint. Besides the state dicts of backbone and heads, everything required by
        :meth:`load_checkpoint` to re-create the model is stored (class, init kwargs, head names, head init
        kwargs and the names of frozen parameters).

        :param Union[Path, str] param_path: path to store checkpoint
        :return:
        """
        # one entry per head, keyed by head name
        state = {name: head.state_dict() for name, head in self.heads.items()}
        state.update(
            {
                "backbone": self.backbone.state_dict(),  # type: ignore[union-attr]
                "__head_names__": list(self.heads.keys()),
                "__init_kwargs__": self._init_kwargs,
                "__target__": self.__class__,
                "__frozen_params__": self._frozen_params,
                "__head_init_kwargs__": self._head_init_kwargs,
            }
        )
        torch.save(state, param_path)
        logger.info("Saved checkpoint!")
        logger.debug(f"@ {param_path}")
class BaseHead(torch.nn.Module, ABC):
    """Abstract base class of all heads that can be attached to a :class:`BaseModel`."""

    def __init__(self, task_type: TaskType, num_classes: int, **kwargs: Any):
        """
        The base class for MML model heads.

        :param TaskType task_type: task type the head is created for, stored as attribute
        :param int num_classes: number of classes of the task, stored as attribute
        :param Any kwargs: accepted for signature compatibility, not used by the base class
        """
        super().__init__()
        self.task_type = task_type
        self.num_classes = num_classes

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        The forward pass of the head, to be implemented by derived classes.

        :param torch.Tensor x: input tensor of the head
        :return: output tensor of the head
        """