# LICENSE HEADER MANAGED BY add-license-header
#
# SPDX-FileCopyrightText: Copyright 2024 German Cancer Research Center (DKFZ) and contributors.
# SPDX-License-Identifier: MIT
#
import logging
from typing import Any, Dict, Optional
import segmentation_models_pytorch as smp
import torch
import torch.nn as nn
from segmentation_models_pytorch.base import ClassificationHead, SegmentationHead, initialization
from mml.core.data_loading.task_attributes import RGBInfo, TaskType
from mml.core.models.torch_base import BaseHead, BaseModel
# module-level logger, named after this module
logger = logging.getLogger(name=__name__)
class SMPGenericModel(BaseModel):
    """Multi-head model wrapping a ``segmentation_models_pytorch`` (smp) architecture.

    The smp model (encoder + decoder) is stored as ``self.backbone``; its default segmentation head is never used.
    Semantic segmentation heads consume the decoder output, all other heads consume the deepest encoder feature map.
    """

    def __init__(self, **kwargs):
        # These attributes are pre-set to None before ``super().__init__``. They are filled in by ``_init_model``,
        # which is presumably invoked from ``BaseModel.__init__`` -- TODO confirm against BaseModel.
        self.arch_name: Optional[str] = None  # architecture, set during _init_model
        self.encoder_name: Optional[str] = None  # encoder, set during _init_model
        self.weights: Optional[str] = None  # encoder pretraining weights, set during _init_model
        self.feature_channels: Optional[int] = None  # encoder output size, set during _init_model
        self.out_channels: Optional[int] = None  # decoder output size, set during _init_model
        super(SMPGenericModel, self).__init__(**kwargs)
        # only used for feature extraction
        self.pooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run encoder and decoder once and feed every head.

        :param x: input batch for the smp encoder
        :return: mapping from head name to that head's output
        """
        features = self.backbone.encoder(x)
        # NOTE(review): newer smp releases may expect ``decoder(features)`` (a list) instead of unpacked
        # arguments -- verify against the pinned smp version before upgrading.
        decoder_output = self.backbone.decoder(*features)
        return {
            name: head(decoder_output if head.task_type == TaskType.SEMANTIC_SEGMENTATION else features[-1])
            for name, head in self.heads.items()
        }

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Return the globally average-pooled deepest encoder feature map.

        Assumes the encoder yields ``(batch, channels, h, w)`` maps; the result is then ``(batch, channels)``.
        """
        return self.pooling(self.backbone.encoder(x)[-1]).squeeze(3).squeeze(2)

    def _init_model(self, arch: str, weights: Optional[str], encoder: str = "resnet34", **kwargs: Any) -> None:
        """Create the smp backbone and record its input requirements and channel sizes.

        :param arch: smp architecture name passed to ``smp.create_model``
        :param weights: encoder pretraining weights name, or ``None`` for no pretrained weights
        :param encoder: smp encoder name
        :param kwargs: accepted for interface compatibility, ignored here
        """
        model = smp.create_model(
            arch=arch,
            encoder_name=encoder,
            encoder_weights=weights,
            in_channels=3,
            classes=1,  # default segmentation head will be discarded
        )
        settings = self._pretrained_settings(encoder=encoder, weights=weights)
        self.input_size = settings.get("input_size")
        if weights:
            # normalisation statistics are only required when pretrained weights are actually loaded
            self.required_mean = RGBInfo(*settings.get("mean"))
            self.required_std = RGBInfo(*settings.get("std"))
        self.arch_name = arch
        self.encoder_name = encoder
        self.weights = weights
        self.backbone = model
        self.feature_channels = self.backbone.encoder.out_channels[-1]
        # smp's SegmentationHead is Sequential(conv2d, upsampling, activation): only the conv at index 0 has
        # ``in_channels`` (index 1 is an Identity / UpsamplingBilinear2d without that attribute)
        self.out_channels = self.backbone.segmentation_head[0].in_channels

    @staticmethod
    def _pretrained_settings(encoder: str, weights: Optional[str]) -> Dict[str, Any]:
        """Look up the smp pretrained-settings entry for ``encoder``.

        If ``weights`` is given, its exact entry is returned. Without weights, ``"imagenet"`` is preferred, falling
        back to the first available entry, because some encoders ship no imagenet weights. If the encoder lists
        no pretrained settings at all, an empty dict is returned.
        """
        available = smp.encoders.encoders[encoder]["pretrained_settings"]
        if weights:
            return available[weights]
        if "imagenet" in available:
            return available["imagenet"]
        return next(iter(available.values()), {})

    def _create_head(self, task_type: TaskType, num_classes: int, **kwargs: Any) -> BaseHead:
        """Create a head sized to the tensor it will consume.

        Segmentation heads use the decoder channels; other heads use the deepest encoder channels.
        """
        return SMPHead(
            task_type=task_type,
            num_classes=num_classes,
            num_features=self.out_channels if task_type == TaskType.SEMANTIC_SEGMENTATION else self.feature_channels,
        )

    def supports(self, task_type: TaskType) -> bool:
        """SMP models support classification, multilabel classification and semantic segmentation tasks."""
        return task_type in [
            TaskType.CLASSIFICATION,
            TaskType.MULTILABEL_CLASSIFICATION,
            TaskType.SEMANTIC_SEGMENTATION,
        ]
class SMPHead(BaseHead):
    """Single task head built from smp's ready-made segmentation / classification heads."""

    def __init__(self, task_type: TaskType, num_classes: int, num_features: int):
        """Build and initialise the smp head matching ``task_type``.

        :param task_type: task this head serves; semantic segmentation gets a SegmentationHead, anything else a
            ClassificationHead
        :param num_classes: number of output classes
        :param num_features: channel count of the tensor this head will receive
        """
        super().__init__(task_type=task_type, num_classes=num_classes)
        wants_segmentation = task_type == TaskType.SEMANTIC_SEGMENTATION
        # NOTE(review): the segmentation branch applies "softmax2d" while the classification branch returns raw
        # scores -- confirm the segmentation loss expects probabilities rather than logits.
        self.head = (
            SegmentationHead(in_channels=num_features, out_channels=num_classes, activation="softmax2d")
            if wants_segmentation
            else ClassificationHead(in_channels=num_features, classes=num_classes)
        )
        initialization.initialize_head(self.head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped smp head."""
        wrapped = self.head
        return wrapped(x)