mml.core.models.torch_base

class BaseHead

Bases: Module, ABC

__init__(task_type: TaskType, num_classes: int, **kwargs: Any)

The base class for MML model heads.

abstract forward(x: Tensor) → Tensor

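BaseHead leaves forward() abstract, so every concrete head has to supply it. The sketch below is hypothetical: it defines a local stand-in, MiniBaseHead, instead of importing mml, passes the task type as a plain string rather than a TaskType, and assumes an extra in_features kwarg carrying the backbone's feature size.

```python
from abc import ABC, abstractmethod

import torch
from torch import nn


class MiniBaseHead(nn.Module, ABC):
    """Hypothetical stand-in for BaseHead: stores task info, leaves forward() abstract."""

    def __init__(self, task_type: str, num_classes: int, **kwargs):
        super().__init__()
        self.task_type = task_type  # a plain string instead of mml's TaskType
        self.num_classes = num_classes

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ...


class LinearHead(MiniBaseHead):
    """Hypothetical concrete head: a single linear layer producing class logits."""

    def __init__(self, task_type: str, num_classes: int, in_features: int = 512, **kwargs):
        super().__init__(task_type, num_classes, **kwargs)
        # `in_features` (the backbone's feature size) is an assumed extra kwarg
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (batch, in_features) -> (batch, num_classes)
        return self.fc(x)


head = LinearHead("classification", num_classes=3, in_features=16)
print(head(torch.randn(4, 16)).shape)  # torch.Size([4, 3])
```

Because the stand-in combines nn.Module with ABC (mirroring "Bases: Module, ABC"), instantiating MiniBaseHead directly raises a TypeError until a subclass implements forward().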
class BaseModel

Bases: Module, ABC

__init__(**kwargs)

The base class for MML models. Derived classes must implement the following methods:
  • _init_model() - for backbone initialization

  • _create_head() - for head creation

  • supports() - for reporting supported task types

  • forward() - the model's forward pass through the backbone and all heads

  • forward_features() - for alternative use as a feature extractor
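To make this contract concrete, here is a minimal, hypothetical model implementing all five hooks on plain torch.nn. It does not subclass mml's BaseModel; the class name TinyMLPModel, the string task types, the nn.ModuleDict holding the heads, and calling _init_model() from __init__ are all assumptions made for illustration. Heads are created by calling _create_head() directly here; in mml they would presumably be added through add_head().

```python
from typing import Dict

import torch
from torch import nn


class TinyMLPModel(nn.Module):
    """Hypothetical model implementing the five hooks listed for BaseModel."""

    SUPPORTED = {"classification"}  # plain strings stand in for TaskType

    def __init__(self, in_features: int = 8, feat_dim: int = 16):
        super().__init__()
        self.in_features = in_features
        self.feat_dim = feat_dim
        self.heads = nn.ModuleDict()  # assumed container for named heads
        self._init_model()

    def _init_model(self) -> None:
        # backbone initialization
        self.backbone = nn.Sequential(nn.Linear(self.in_features, self.feat_dim), nn.ReLU())

    def _create_head(self, name: str, num_classes: int) -> None:
        # head creation: one linear classifier per task
        self.heads[name] = nn.Linear(self.feat_dim, num_classes)

    def supports(self, task_type: str) -> bool:
        return task_type in self.SUPPORTED

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # backbone only, flattened to one vector per sample
        return torch.flatten(self.backbone(x), start_dim=1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        feats = self.forward_features(x)  # backbone runs once, shared by all heads
        return {name: head(feats) for name, head in self.heads.items()}


model = TinyMLPModel()
model._create_head("task_a", num_classes=3)
model._create_head("task_b", num_classes=5)
out = model(torch.randn(2, 8))
print({k: tuple(v.shape) for k, v in out.items()})  # {'task_a': (2, 3), 'task_b': (2, 5)}
```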

add_head(task_struct: TaskStruct, **kwargs: Any) → None

Adds a head for the given task to the model.

Parameters:
  • task_struct (TaskStruct) – struct for the task to add a head for

  • kwargs (Any) – additional kwargs that will be forwarded
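The following sketch shows one plausible shape for add_head(): validate the task type, then build and register a head, forwarding extra keyword arguments. TaskStructStandIn and its fields (name, task_type, num_classes), the duplicate-name check, and forwarding the kwargs to nn.Linear are all assumptions; mml's actual TaskStruct and head construction are not documented on this page.

```python
from dataclasses import dataclass
from typing import Any

from torch import nn


@dataclass
class TaskStructStandIn:
    # hypothetical fields; mml's real TaskStruct is not documented here
    name: str
    task_type: str
    num_classes: int


class HeadRegistryModel(nn.Module):
    """Hypothetical model that keeps one named head per task."""

    def __init__(self, feat_dim: int = 16):
        super().__init__()
        self.feat_dim = feat_dim
        self.backbone = nn.Linear(8, feat_dim)
        self.heads = nn.ModuleDict()

    def supports(self, task_type: str) -> bool:
        return task_type == "classification"

    def add_head(self, task_struct: TaskStructStandIn, **kwargs: Any) -> None:
        if not self.supports(task_struct.task_type):
            raise ValueError(f"unsupported task type: {task_struct.task_type}")
        if task_struct.name in self.heads:
            raise ValueError(f"head {task_struct.name!r} already exists")
        # extra kwargs (here an optional `bias` flag) are forwarded to head creation
        self.heads[task_struct.name] = nn.Linear(self.feat_dim, task_struct.num_classes, **kwargs)


model = HeadRegistryModel()
model.add_head(TaskStructStandIn("lesions", "classification", 4), bias=False)
print(list(model.heads))  # ['lesions']
```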

count_parameters(only_trainable: bool = True) → Dict[str, int]

Gives information on parameter count of the model.

Parameters:

only_trainable (bool) – if True, only counts parameters with requires_grad=True.

Returns:

a dict mapping each component name (the backbone or a head's name) to its parameter count
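A minimal sketch of the counting logic, assuming the model exposes its backbone as one module and its heads as an nn.ModuleDict. The standalone count_parameters function here is hypothetical, not mml's implementation. The expected counts follow from nn.Linear(in, out) having in*out weights plus out biases.

```python
from typing import Dict

from torch import nn


def count_parameters(backbone: nn.Module, heads: nn.ModuleDict, only_trainable: bool = True) -> Dict[str, int]:
    """Hypothetical per-component parameter count."""

    def count(module: nn.Module) -> int:
        # skip frozen parameters only when only_trainable is set
        return sum(p.numel() for p in module.parameters() if p.requires_grad or not only_trainable)

    counts = {"backbone": count(backbone)}
    counts.update({name: count(head) for name, head in heads.items()})
    return counts


backbone = nn.Linear(10, 20)                         # 10*20 + 20 = 220 parameters
heads = nn.ModuleDict({"task_a": nn.Linear(20, 3)})  # 20*3 + 3 = 63 parameters
for p in backbone.parameters():
    p.requires_grad = False                          # simulate a frozen backbone

print(count_parameters(backbone, heads))                        # {'backbone': 0, 'task_a': 63}
print(count_parameters(backbone, heads, only_trainable=False))  # {'backbone': 220, 'task_a': 63}
```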

abstract forward(x: Tensor) → Dict[str, Tensor]

Model forward functionality. Passes the input through the backbone once and forwards the output through each head.

Parameters:

x (torch.Tensor) – input tensor

Returns:

a dictionary with one entry per head, mapping each head name to that head's output
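The "backbone once" behaviour can be checked with a forward hook. In this hypothetical SharedBackboneModel the backbone output is computed a single time and reused by every head, so the hook fires once per call no matter how many heads exist.

```python
from typing import Dict

import torch
from torch import nn


class SharedBackboneModel(nn.Module):
    """Hypothetical two-head model sharing one backbone pass."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.heads = nn.ModuleDict({"a": nn.Linear(16, 2), "b": nn.Linear(16, 7)})

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        features = self.backbone(x)  # computed once, shared by every head
        return {name: head(features) for name, head in self.heads.items()}


model = SharedBackboneModel()
calls = []
# the hook records each backbone invocation and leaves the output unchanged
model.backbone.register_forward_hook(lambda module, inputs, output: calls.append(1))
outputs = model(torch.randn(4, 8))
print(len(calls), {k: tuple(v.shape) for k, v in outputs.items()})  # 1 {'a': (4, 2), 'b': (4, 7)}
```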

abstract forward_features(x: Tensor) → Tensor

Feature extraction functionality. Only forwards the input through the backbone and post-processes the features to 1D.

Parameters:

x (torch.Tensor) – input tensor

Returns:

a 1D tensor
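How forward_features() reduces backbone features to 1D is not specified here. One common approach for convolutional backbones, assumed in this sketch, is global average pooling followed by flattening, which yields one 1D feature vector per sample while keeping the batch dimension. The backbone, pool, and forward_features names below are illustrative.

```python
import torch
from torch import nn

# hypothetical convolutional backbone producing an 8-channel feature map
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())
pool = nn.AdaptiveAvgPool2d(1)


def forward_features(x: torch.Tensor) -> torch.Tensor:
    fmap = backbone(x)               # (B, 8, H, W)
    pooled = pool(fmap)              # (B, 8, 1, 1)
    return torch.flatten(pooled, 1)  # (B, 8): one 1D vector per sample


feats = forward_features(torch.randn(2, 3, 32, 32))
print(feats.shape)  # torch.Size([2, 8])
```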

freeze_backbone() → None

Freezes all backbone parameters.
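Freezing usually means disabling gradients for the backbone's parameters so that only the heads are trained. The sketch below assumes that is what freeze_backbone() does; FreezableModel and its layer sizes are hypothetical.

```python
from torch import nn


class FreezableModel(nn.Module):
    """Hypothetical model with a backbone and one head."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 16))
        self.heads = nn.ModuleDict({"task": nn.Linear(16, 3)})

    def freeze_backbone(self) -> None:
        # stop gradient flow into every backbone parameter; heads stay trainable
        for param in self.backbone.parameters():
            param.requires_grad = False


model = FreezableModel()
model.freeze_backbone()
trainable = [n for n, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['heads.task.weight', 'heads.task.bias']
```

Setting requires_grad=False does not switch layers such as BatchNorm or Dropout into eval mode; whether mml's freeze_backbone() also does that is not stated here.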

static load_checkpoint(param_path: Path | str) → BaseModel

Loads a model from a checkpoint.

Parameters:

param_path (Union[Path, str]) – path to load checkpoint from

Returns:

the model restored from the checkpoint

save_checkpoint(param_path: Path | str) → None

Saves a model checkpoint.

Parameters:

param_path (Union[Path, str]) – path to store checkpoint
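Because load_checkpoint() is static and returns a BaseModel, a checkpoint presumably has to carry enough information to rebuild the model, not just its weights. This hypothetical round trip stores the constructor kwargs next to the state_dict using torch.save and torch.load; CheckpointModel is illustrative and mml's real checkpoint format is not documented here.

```python
import tempfile
from pathlib import Path
from typing import Union

import torch
from torch import nn


class CheckpointModel(nn.Module):
    """Hypothetical model with save/load helpers mirroring the signatures above."""

    def __init__(self, hidden: int = 16):
        super().__init__()
        self.hidden = hidden
        self.backbone = nn.Linear(8, hidden)

    def save_checkpoint(self, param_path: Union[Path, str]) -> None:
        # store constructor arguments alongside the weights so loading can rebuild the model
        torch.save({"init_kwargs": {"hidden": self.hidden}, "state_dict": self.state_dict()}, param_path)

    @staticmethod
    def load_checkpoint(param_path: Union[Path, str]) -> "CheckpointModel":
        checkpoint = torch.load(param_path, map_location="cpu")
        model = CheckpointModel(**checkpoint["init_kwargs"])
        model.load_state_dict(checkpoint["state_dict"])
        return model


original = CheckpointModel(hidden=32)
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pth"
    original.save_checkpoint(path)
    restored = CheckpointModel.load_checkpoint(path)

x = torch.randn(2, 8)
print(torch.equal(original.backbone(x), restored.backbone(x)))  # True
```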

abstract supports(task_type: TaskType) → bool

Whether the model supports a given task type.

Parameters:

task_type (TaskType) – the task type to check

Returns:

True if and only if the task type is supported by the model
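A supports() implementation can be as simple as a set membership test. TaskTypeStandIn and its members are hypothetical, since mml's TaskType values are not listed on this page, and ImageLevelModel is an illustrative class.

```python
from enum import Enum


class TaskTypeStandIn(Enum):
    # hypothetical members; mml's real TaskType enumeration is not listed here
    CLASSIFICATION = "classification"
    SEGMENTATION = "segmentation"
    REGRESSION = "regression"


class ImageLevelModel:
    """Hypothetical model producing one prediction per image, so no segmentation."""

    _supported = frozenset({TaskTypeStandIn.CLASSIFICATION, TaskTypeStandIn.REGRESSION})

    def supports(self, task_type: TaskTypeStandIn) -> bool:
        return task_type in self._supported


model = ImageLevelModel()
print([t.name for t in TaskTypeStandIn if model.supports(t)])  # ['CLASSIFICATION', 'REGRESSION']
```

A natural place to call supports() is before creating a head, so that an unsupported task fails early instead of during training.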

unfreeze_backbone() → None

Unfreezes previously frozen backbone parameters.
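Freezing and unfreezing are often combined in a two-stage fine-tuning schedule: first train only the heads, then train the whole network. This sketch assumes both methods toggle requires_grad on the backbone; TwoStageModel and n_trainable are hypothetical, and unfreeze_backbone() here simply re-enables every backbone parameter rather than tracking which ones were frozen.

```python
from torch import nn


class TwoStageModel(nn.Module):
    """Hypothetical model with freeze/unfreeze helpers."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 16)
        self.head = nn.Linear(16, 3)

    def freeze_backbone(self) -> None:
        for p in self.backbone.parameters():
            p.requires_grad = False

    def unfreeze_backbone(self) -> None:
        # restore gradient flow so the backbone is updated again
        for p in self.backbone.parameters():
            p.requires_grad = True


def n_trainable(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


model = TwoStageModel()
model.freeze_backbone()
print(n_trainable(model))  # 51: stage 1, only the head (16*3 + 3) trains
model.unfreeze_backbone()
print(n_trainable(model))  # 195: stage 2, the backbone (8*16 + 16 = 144) trains too
```

An optimizer constructed only from the parameters that were trainable during stage 1 does not know about the backbone; rebuild it, or call its add_param_group method, after unfreezing.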