# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
from __future__ import annotations
import logging
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
import hydra.utils
import peft
import torch
from humanize import intword
from peft import LoraConfig
from peft.config import PeftConfig
from sqlalchemy.testing.plugin.plugin_base import warnings
from transformers import Conv1D
from mml.core.data_loading.task_attributes import RGBInfo, TaskType
from mml.core.data_loading.task_struct import TaskStruct
from mml.core.scripts.decorators import beta
from mml.core.scripts.exceptions import MMLMisconfigurationException
logger = logging.getLogger(__name__)
[docs]
class BaseModel(torch.nn.Module, ABC):
    """Abstract base for MML models: a single backbone plus a named collection of task heads."""

    def __init__(self, **kwargs):
        """
        The base class for MML models. Derived classes must implement the following methods:

        - :meth:`_init_model` - for backbone initialization
        - :meth:`_create_head` - for head creation
        - :meth:`supports` - reporting supported task types
        - :meth:`forward` - the models forward pass through backbone and all heads
        - :meth:`forward_features` - alternative usage as feature extractor

        :param Any kwargs: forwarded unchanged to :meth:`_init_model` and stored so that a checkpoint can
            re-instantiate the model with identical arguments
        """
        super().__init__()
        # model requirements
        self.required_mean: Optional[RGBInfo] = None  # mean expected by model
        self.required_std: Optional[RGBInfo] = None  # std expected by model
        self.input_size = (None, None, None)  # channel, height, width - will be defined during init
        # nn modules
        self.backbone: Union[torch.nn.Module, None] = None
        self.heads = torch.nn.ModuleDict({})
        # for freezing functionality
        self._frozen_params: List[str] = []
        # store init kwargs
        self._init_kwargs: Dict[str, Any] = kwargs  # stores stuff that needs to be persistent when re-initializing
        # stores init kwargs of heads, kept in the same order in which heads were inserted into self.heads
        self._head_init_kwargs: List[Dict[str, Any]] = []
        self._peft_kwargs: Dict[str, Any] = {}  # stores any peft kwargs
        # actually init backbone
        self._init_model(**kwargs)
        logger.debug("Model initialised.")

    @abstractmethod
    def _init_model(self, **kwargs: Any) -> None:
        """
        This shall implement the backbone module as well as potentially load pretrained weights thereof.

        :param Any kwargs: the keyword arguments the model was constructed with
        """
        raise NotImplementedError

    @abstractmethod
    def _create_head(self, task_type: TaskType, num_classes: int, **kwargs: Any) -> BaseHead:
        """
        This shall implement the creation of heads. Given a certain task type the head must be able to be attached
        to the backbone as implemented by the forward method.

        :param TaskType task_type: task type the head is created for
        :param int num_classes: number of classes of the task
        :param Any kwargs: additional head specific kwargs
        :return: the newly created head
        """
        raise NotImplementedError

    @abstractmethod
    def supports(self, task_type: TaskType) -> bool:
        """
        Whether the model supports a given task type.

        :param TaskType task_type:
        :return: true iff task type is supported by model
        """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Model forward functionality. Passes input through the backbone once and forwards output through each head.

        :param torch.tensor x: input tensor
        :return: a dictionary, with one entry per head, key is head name and value is head output
        """

    @abstractmethod
    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """
        Feature extraction functionality. Only forwards features through the backbone and post-processes them to 1D.

        :param torch.tensor x: input tensor
        :return: a 1D tensor
        """

    def add_head(self, task_struct: TaskStruct, **kwargs: Any) -> None:
        """
        The functionality used to add heads to a model.

        :param TaskStruct task_struct: struct for the task to add a head for
        :param Any kwargs: additional kwargs that will be forwarded
        :raises KeyError: if a head with the name of the task is already registered
        :raises RuntimeError: if the task type is not supported by the model
        """
        if task_struct.name in self.heads:
            raise KeyError(f"You cannot register a head with already present name ({task_struct.name}).")
        if not self.supports(task_type=task_struct.task_type):
            raise RuntimeError(f"Task type {task_struct.task_type} not supported by model.")
        init_kwargs = {"task_type": task_struct.task_type, "num_classes": task_struct.num_classes}
        init_kwargs.update(kwargs)
        self.heads[task_struct.name] = self._create_head(**init_kwargs)
        # remembered so that load_checkpoint is able to re-create the very same head
        self._head_init_kwargs.append(init_kwargs)
        logger.debug(
            f"Added head {task_struct.name} of task type {task_struct.task_type} with "
            f"{task_struct.num_classes} classes."
        )

    def count_parameters(self, only_trainable: bool = True) -> Dict[str, int]:
        """
        Gives information on parameter count of the model.

        :param bool only_trainable: if True, only counts parameters that requires_grad.
        :return: a dict with component names as key (backbone or name of heads) and parameter count as value
        """
        info_dict = {}
        for name, module in [("backbone", self.backbone)] + list(self.heads.items()):
            if module is None:
                # a backbone that has not been set yet is reported with zero parameters
                info_dict[name] = 0
            else:
                info_dict[name] = sum(p.numel() for p in module.parameters() if p.requires_grad or not only_trainable)
        return info_dict

    def freeze_backbone(self) -> None:
        """
        Freezes all backbone parameters.

        Only parameters that are currently trainable are touched; their names are remembered so that
        :meth:`unfreeze_backbone` can restore exactly those.
        """
        for name, par in self.backbone.named_parameters():  # type: ignore[union-attr]
            if par.requires_grad:
                par.requires_grad = False
                self._frozen_params.append(name)
        logger.debug(f"Froze {len(self._frozen_params)} parameters of model.")

    def unfreeze_backbone(self) -> None:
        """
        Unfreezes previously frozen backbone parameters.
        """
        # set lookup avoids scanning the whole list of names once per parameter
        frozen_names = set(self._frozen_params)
        for name, par in self.backbone.named_parameters():  # type: ignore[union-attr]
            if name in frozen_names:
                par.requires_grad = True
        logger.debug(f"Unfroze {len(self._frozen_params)} params of model.")
        self._frozen_params = []

    @beta("PEFT integration is still in beta.")
    def set_peft(self, peft_cfg: PeftConfig) -> None:
        """
        Applies a PEFT (Parameter Efficient FineTuning) method to the model. Usually this will lead to adapters injected
        to the base model that complement existing weights. The advantage is that the majority of existing weights is
        frozen (the .requires_grad attribute of the tensors is set to false) while only the smaller adapters are kept
        trainable.

        :param PeftConfig peft_cfg: PEFTConfig instance, see
            `huggingface/peft <https://github.com/huggingface/peft/tree/main>`_
        :raises RuntimeError: if PEFT has already been applied to this model
        :raises MMLMisconfigurationException: if the config is a prompt learning or adaption prompt method
        :return: None, since model is modified in place
        """
        # The module level ``warnings`` name of this file is imported from an SQLAlchemy testing plugin and not
        # from the standard library, so the stdlib module is bound locally to make sure ``warn`` is available.
        import warnings

        if self._peft_kwargs:
            raise RuntimeError("PEFT already set for this model!")
        # validate before touching any state, a rejected config must leave the model as it was
        if peft_cfg.is_prompt_learning or peft_cfg.is_adaption_prompt:
            raise MMLMisconfigurationException(f"Applying {peft_cfg.peft_type} is likely an unsupported PEFT type.")
        if self._frozen_params:
            warnings.warn(
                "Backbone was frozen prior to applying PEFT, will first unfreeze backbone and then apply. "
                "You may re-freeze the backbone (i.e. the injected adapters)."
            )
            self.unfreeze_backbone()
        # snapshot the config before "auto" is resolved below, so a checkpoint stores what the user requested
        peft_kwargs = peft_cfg.to_dict()
        if isinstance(peft_cfg, LoraConfig) and peft_cfg.target_modules == "auto":
            peft_cfg.target_modules = self.get_lora_compatible_layers(self.backbone)
            logger.info(f"Auto detected {len(peft_cfg.target_modules)} compatible layers for LoRa in model backbone.")
        pre_params = self.count_parameters(only_trainable=True)["backbone"]
        self.backbone = peft.get_peft_model(model=self.backbone, peft_config=peft_cfg)
        # only mark PEFT as applied once the backbone was wrapped successfully
        self._peft_kwargs = peft_kwargs
        post_params = self.count_parameters(only_trainable=True)["backbone"]
        # a backbone without any trainable parameter must not break the log message
        trainable_ratio = post_params / pre_params if pre_params else 0.0
        logger.info(
            f"After applying {peft_cfg.peft_type} from {intword(pre_params)} params only {intword(post_params)}"
            f" remain trainable (={trainable_ratio:.2%})."
        )

    @staticmethod
    def get_lora_compatible_layers(backbone: torch.nn.Module) -> List[str]:
        """
        Helper function to extract all Lora compatible layers (from the peft library).

        :param torch.nn.Module backbone: the model to extract layers from
        :return: list of strings that correspond to the respective layer names
        """
        # these are the currently LORA supported layers
        supported_layers = (torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, Conv1D)
        return [name for name, module in backbone.named_modules() if isinstance(module, supported_layers)]

    @staticmethod
    def load_checkpoint(param_path: Union[Path, str]) -> "BaseModel":
        """
        Load from a checkpoint. Be aware that MML uses its own checkpoint structure (different from the one in
        `lightning <https://github.com/Lightning-AI/lightning>`_). Detail can be found in
        :meth:`~mml.core.models.torch_base.BaseModel.save_checkpoint`.

        :param Union[Path, str] param_path: path to load checkpoint from
        :return: the re-instantiated model with backbone and head weights restored
        """
        # SECURITY: weights_only=False unpickles arbitrary objects (the checkpoint stores the model class and its
        # init kwargs), hence only checkpoints from trusted sources must be loaded.
        state = torch.load(param_path, weights_only=False)
        model: BaseModel = hydra.utils.instantiate(dict(_target_=state["__target__"], **state["__init_kwargs__"]))
        # for backward compatibility we check whether the keyword is present in the state
        if "__peft_kwargs__" in state and len(state["__peft_kwargs__"]) > 0:
            peft_cfg = PeftConfig.from_peft_type(**state["__peft_kwargs__"])
            model.set_peft(peft_cfg)
        model.backbone.load_state_dict(state["backbone"])  # type: ignore[union-attr]
        # NOTE(review): only the list of names is restored, the requires_grad flags of the freshly created
        # parameters are not modified here - confirm whether a frozen backbone should be re-frozen on load
        model._frozen_params = state["__frozen_params__"]
        head_names = state["__head_names__"]
        head_init_kwargs = state["__head_init_kwargs__"]
        if len(head_names) != len(head_init_kwargs):
            # zip below stops at the shorter sequence, make the resulting loss of heads visible
            logger.warning(
                f"Checkpoint lists {len(head_names)} heads but {len(head_init_kwargs)} head init kwargs, "
                f"only the first {min(len(head_names), len(head_init_kwargs))} heads will be restored."
            )
        for head_name, init_kwargs in zip(head_names, head_init_kwargs):
            head = model._create_head(**init_kwargs)
            model.heads[head_name] = head
            head.load_state_dict(state[head_name])
            # keep the bookkeeping in sync with the heads, otherwise saving the loaded model again would write
            # an empty "__head_init_kwargs__" and the heads could not be restored from that checkpoint
            model._head_init_kwargs.append(init_kwargs)
        logger.info("Loaded MML checkpoint!")
        logger.debug(f"@ {param_path}")
        return model

    def save_checkpoint(self, param_path: Union[Path, str]) -> None:
        """
        Save a model checkpoint.

        The checkpoint is a single dict holding one state dict per head (keyed by head name), the backbone state
        dict and the meta information required by :meth:`load_checkpoint` to re-create the model.

        :param Union[Path, str] param_path: path to store checkpoint
        :return:
        """
        state = {name: head.state_dict() for name, head in self.heads.items()}
        state.update(
            {
                "backbone": self.backbone.state_dict(),  # type: ignore[union-attr]
                "__head_names__": list(self.heads.keys()),
                "__init_kwargs__": self._init_kwargs,
                "__target__": self.__class__,
                "__frozen_params__": self._frozen_params,
                "__head_init_kwargs__": self._head_init_kwargs,
                "__peft_kwargs__": self._peft_kwargs,
            }
        )
        torch.save(state, param_path)
        logger.info("Saved checkpoint!")
        logger.debug(f"@ {param_path}")
[docs]
class BaseHead(torch.nn.Module, ABC):
    """Abstract base for the task specific heads of an MML model."""

    def __init__(self, task_type: TaskType, num_classes: int, **kwargs: Any):
        """
        The base class for MML model heads.

        :param TaskType task_type: task type the head is meant for, stored as attribute
        :param int num_classes: number of classes of the task, stored as attribute
        :param Any kwargs: accepted for signature compatibility, not used by the base class
        """
        super().__init__()
        self.task_type, self.num_classes = task_type, num_classes

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Head forward functionality, to be implemented by derived classes.

        :param torch.tensor x: input tensor
        :return: the output tensor of the head
        """