mml.core.models.torch_base

class BaseHead

Bases: Module, ABC

__init__(task_type: TaskType, num_classes: int, **kwargs: Any)

The base class for MML model heads.

abstract forward(x: Tensor) → Tensor
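
The pattern behind BaseHead (an nn.Module that is also an ABC with an abstract forward()) can be sketched in plain PyTorch as follows. This is a minimal illustration, not MML's code: `SketchBaseHead` and `LinearHead` are hypothetical names, the `in_features` argument is an assumption, and `TaskType` is replaced by a plain string.

```python
from abc import ABC, abstractmethod
from typing import Any

import torch
from torch import Tensor, nn


class SketchBaseHead(nn.Module, ABC):
    """Hypothetical stand-in for BaseHead; task_type is a plain string here."""

    def __init__(self, task_type: str, num_classes: int, **kwargs: Any) -> None:
        super().__init__()
        self.task_type = task_type
        self.num_classes = num_classes

    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        ...


class LinearHead(SketchBaseHead):
    """Concrete head: a single linear layer mapping features to class logits."""

    def __init__(self, task_type: str, num_classes: int, in_features: int, **kwargs: Any) -> None:
        super().__init__(task_type, num_classes, **kwargs)
        self.fc = nn.Linear(in_features, num_classes)

    def forward(self, x: Tensor) -> Tensor:
        return self.fc(x)


head = LinearHead("classification", num_classes=3, in_features=16)
logits = head(torch.randn(4, 16))
print(logits.shape)  # torch.Size([4, 3])
```

Because forward() is abstract, the base class itself cannot be instantiated; only concrete heads such as `LinearHead` can.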

class BaseModel

Bases: Module, ABC

__init__(**kwargs)
The base class for MML models. Derived classes must implement the following methods:
  • _init_model() - for backbone initialization

  • _create_head() - for head creation

  • supports() - reporting supported task types

  • forward() - the model's forward pass through the backbone and all heads

  • forward_features() - alternative usage as a feature extractor
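
A minimal sketch of how these pieces might fit together, assuming a toy MLP backbone and linear heads. `ToyMultiHeadModel` is hypothetical, `TaskType` is replaced by strings, and its `add_head(name, task_type, num_classes)` signature is a simplification of the documented `add_head(task_struct, **kwargs)`; it illustrates the contract, not MML's implementation.

```python
from typing import Dict, Set

import torch
from torch import Tensor, nn


class ToyMultiHeadModel(nn.Module):
    """Hypothetical sketch of the BaseModel contract: one shared backbone, several named heads."""

    SUPPORTED: Set[str] = {"classification"}  # stand-in for supported TaskType values

    def __init__(self, in_features: int = 8, feat_dim: int = 16) -> None:
        super().__init__()
        self.feat_dim = feat_dim
        self._init_model(in_features)
        self.heads = nn.ModuleDict()  # head name -> head module

    def _init_model(self, in_features: int) -> None:
        # Backbone initialization (a tiny MLP here instead of a real pretrained network).
        self.backbone = nn.Sequential(nn.Linear(in_features, self.feat_dim), nn.ReLU())

    def _create_head(self, num_classes: int) -> nn.Module:
        return nn.Linear(self.feat_dim, num_classes)

    def supports(self, task_type: str) -> bool:
        return task_type in self.SUPPORTED

    def add_head(self, name: str, task_type: str, num_classes: int) -> None:
        if not self.supports(task_type):
            raise ValueError(f"unsupported task type: {task_type}")
        self.heads[name] = self._create_head(num_classes)

    def forward_features(self, x: Tensor) -> Tensor:
        # Backbone only, flattened to one feature vector per sample.
        return torch.flatten(self.backbone(x), start_dim=1)

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        feats = self.forward_features(x)  # the backbone runs once ...
        return {name: head(feats) for name, head in self.heads.items()}  # ... then every head


model = ToyMultiHeadModel()
model.add_head("task_a", "classification", num_classes=3)
model.add_head("task_b", "classification", num_classes=5)
out = model(torch.randn(2, 8))
print({k: tuple(v.shape) for k, v in out.items()})  # {'task_a': (2, 3), 'task_b': (2, 5)}
```

Because the backbone runs once per call and its features are reused by every head, each additional head only adds the cost of the head itself.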

add_head(task_struct: TaskStruct, **kwargs: Any) → None

Adds a head to the model for the given task.

Parameters:
  • task_struct (TaskStruct) – struct for the task to add a head for

  • kwargs (Any) – additional kwargs that will be forwarded

count_parameters(only_trainable: bool = True) → Dict[str, int]

Reports the parameter count of the model, broken down by component.

Parameters:

only_trainable (bool) – if True, only counts parameters with requires_grad=True.

Returns:

a dict with component names as keys (the backbone or the name of each head) and parameter counts as values
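
The per-component counting described above can be sketched with `Module.parameters()` and `Tensor.numel()`. The standalone function below and its `backbone`/`heads` arguments are hypothetical; only the shape of the returned dict follows the description.

```python
from typing import Dict

from torch import nn


def count_parameters(backbone: nn.Module, heads: nn.ModuleDict, only_trainable: bool = True) -> Dict[str, int]:
    """Hypothetical helper: maps each component name to its parameter count."""

    def _count(module: nn.Module) -> int:
        # Skip frozen parameters only when only_trainable is set.
        return sum(p.numel() for p in module.parameters() if p.requires_grad or not only_trainable)

    counts = {"backbone": _count(backbone)}
    counts.update({name: _count(head) for name, head in heads.items()})
    return counts


backbone = nn.Linear(4, 8)                          # 4*8 weights + 8 biases = 40 parameters
heads = nn.ModuleDict({"task_a": nn.Linear(8, 3)})  # 8*3 + 3 = 27 parameters
backbone.requires_grad_(False)                      # a frozen backbone is no longer trainable
print(count_parameters(backbone, heads))                        # {'backbone': 0, 'task_a': 27}
print(count_parameters(backbone, heads, only_trainable=False))  # {'backbone': 40, 'task_a': 27}
```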

abstract forward(x: Tensor) → Dict[str, Tensor]

Model forward pass. Passes the input through the backbone once and forwards the output through each head.

Parameters:

x (torch.Tensor) – input tensor

Returns:

a dictionary with one entry per head; the key is the head name and the value is the head output

abstract forward_features(x: Tensor) → Tensor

Feature extraction. Only forwards the input through the backbone and post-processes the features to 1D.

Parameters:

x (torch.Tensor) – input tensor

Returns:

a 1D tensor
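
A sketch of one common way to reduce backbone output to a flat feature vector: global average pooling followed by flattening. This assumes a convolutional backbone and that "1D" means one feature vector per sample, so a batched input yields shape (batch, features); MML's actual post-processing may differ, and the backbone below is hypothetical.

```python
import torch
from torch import Tensor, nn

# Hypothetical convolutional backbone producing (batch, channels, H, W) feature maps.
backbone = nn.Sequential(nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU())
pool = nn.AdaptiveAvgPool2d(1)  # global average pooling -> (batch, channels, 1, 1)


def forward_features(x: Tensor) -> Tensor:
    """Backbone only; spatial maps are reduced to one 1D feature vector per sample."""
    fmap = backbone(x)
    return torch.flatten(pool(fmap), start_dim=1)  # (batch, channels)


feats = forward_features(torch.randn(2, 3, 64, 64))
print(feats.shape)  # torch.Size([2, 32])
```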

freeze_backbone() → None

Freezes all backbone parameters.
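
Freezing and unfreezing (see also unfreeze_backbone()) usually come down to toggling requires_grad on the backbone's parameters. The sketch below assumes a model that exposes a `backbone` submodule; `FreezableModel` and `n_trainable` are hypothetical names.

```python
from torch import nn


class FreezableModel(nn.Module):
    """Hypothetical model with a backbone and one head, used only to illustrate freezing."""

    def __init__(self) -> None:
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)

    def freeze_backbone(self) -> None:
        # Stop computing gradients for every backbone parameter; the head stays trainable.
        self.backbone.requires_grad_(False)

    def unfreeze_backbone(self) -> None:
        self.backbone.requires_grad_(True)


def n_trainable(m: nn.Module) -> int:
    return sum(p.numel() for p in m.parameters() if p.requires_grad)


model = FreezableModel()
model.freeze_backbone()
print(n_trainable(model))  # 18 (only the head: 8*2 + 2)
model.unfreeze_backbone()
print(n_trainable(model))  # 58 (backbone 4*8 + 8 = 40, plus head 18)
```

Note that requires_grad only controls gradient computation: BatchNorm running statistics inside a frozen backbone still update in training mode unless those layers are also switched to eval().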

static get_lora_compatible_layers(backbone: Module) → List[str]

Helper function to extract the names of all layers compatible with LoRA (as implemented in the peft library).

Parameters:

backbone (torch.nn.Module) – the model to extract layers from

Returns:

a list of strings corresponding to the names of the compatible layers
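
A sketch of how such layer names can be collected with `named_modules()`. It assumes that `nn.Linear`, `nn.Conv2d` and `nn.Embedding` are the LoRA-compatible types; the exact set depends on the installed peft version, and `lora_compatible_layers` is a hypothetical stand-in for the documented helper.

```python
from typing import List, Tuple, Type

from torch import nn

# Assumption: layer types LoRA can wrap; the real set depends on the peft version.
LORA_TYPES: Tuple[Type[nn.Module], ...] = (nn.Linear, nn.Conv2d, nn.Embedding)


def lora_compatible_layers(backbone: nn.Module) -> List[str]:
    """Hypothetical helper: qualified names of submodules whose type LoRA can adapt."""
    return [name for name, module in backbone.named_modules() if isinstance(module, LORA_TYPES)]


net = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8 * 30 * 30, 10),
)
print(lora_compatible_layers(net))  # ['0', '3']
```

Names in this dotted form are what peft's `LoraConfig` accepts in its `target_modules` argument.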

static load_checkpoint(param_path: Path | str) → BaseModel

Loads a model from a checkpoint. Be aware that MML uses its own checkpoint structure (different from the one in Lightning). Details can be found in save_checkpoint().

Parameters:

param_path (Union[Path, str]) – path to load checkpoint from

Returns:

the model restored from the checkpoint (a BaseModel instance)

save_checkpoint(param_path: Path | str) → None

Save a model checkpoint.

Parameters:

param_path (Union[Path, str]) – path to store checkpoint

Returns:

None
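
MML's actual checkpoint layout is not spelled out in this section, so the following only sketches the general idea behind a custom format: store enough information to rebuild the architecture (here, constructor kwargs) alongside the `state_dict()`, then rebuild and load. `CheckpointableModel`, its `init_kwargs` attribute, and the dict keys are all hypothetical.

```python
import tempfile
from pathlib import Path
from typing import Any, Dict, Union

import torch
from torch import nn


class CheckpointableModel(nn.Module):
    """Hypothetical model that remembers its constructor kwargs so it can be rebuilt on load."""

    def __init__(self, hidden: int = 8) -> None:
        super().__init__()
        self.init_kwargs: Dict[str, Any] = {"hidden": hidden}
        self.net = nn.Linear(4, hidden)

    def save_checkpoint(self, param_path: Union[Path, str]) -> None:
        # Hypothetical layout: constructor kwargs + weights in one file (not MML's actual format).
        torch.save({"init_kwargs": self.init_kwargs, "state_dict": self.state_dict()}, param_path)

    @staticmethod
    def load_checkpoint(param_path: Union[Path, str]) -> "CheckpointableModel":
        ckpt = torch.load(param_path, map_location="cpu", weights_only=True)
        model = CheckpointableModel(**ckpt["init_kwargs"])  # rebuild the architecture first ...
        model.load_state_dict(ckpt["state_dict"])            # ... then restore the weights
        return model


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pt"
    original = CheckpointableModel(hidden=16)
    original.save_checkpoint(path)
    restored = CheckpointableModel.load_checkpoint(path)
    print(torch.equal(original.net.weight, restored.net.weight))  # True
```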

set_peft(peft_cfg: PeftConfig) → None

Applies a PEFT (Parameter-Efficient Fine-Tuning) method to the model. Usually this injects adapters into the base model that complement the existing weights. The advantage is that the majority of existing weights are frozen (the .requires_grad attribute of the tensors is set to False), while only the smaller adapters are kept trainable.

Parameters:

peft_cfg (PeftConfig) – PeftConfig instance, see huggingface/peft

Returns:

None, since the model is modified in place
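
To make the frozen-base/trainable-adapter idea concrete, here is a hand-rolled LoRA adapter around a single linear layer. LoRA is one PEFT method; this is not peft's implementation, and `LoRALinear`, `r`, and `alpha` defaults are illustrative assumptions. With set_peft you would presumably pass a peft config such as `LoraConfig` (a `PeftConfig` subclass) instead of wrapping layers by hand.

```python
import torch
from torch import Tensor, nn


class LoRALinear(nn.Module):
    """Hand-rolled LoRA adapter around a frozen nn.Linear (illustration only, not peft's code)."""

    def __init__(self, base: nn.Linear, r: int = 2, alpha: float = 4.0) -> None:
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # original weights are frozen
        # Low-rank update W + (alpha / r) * B @ A; B starts at zero so the output initially equals the base.
        self.lora_A = nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(base.out_features, r))
        self.scaling = alpha / r

    def forward(self, x: Tensor) -> Tensor:
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling


base = nn.Linear(8, 8)
layer = LoRALinear(base, r=2)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in layer.parameters() if not p.requires_grad)
print(trainable, frozen)  # 32 72
x = torch.randn(3, 8)
print(torch.allclose(layer(x), base(x)))  # True, because lora_B is zero-initialized
```

Only the 32 adapter parameters (2×8 in A, 8×2 in B) receive gradients, while the 72 base parameters (64 weights, 8 biases) stay fixed.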

abstract supports(task_type: TaskType) → bool

Whether the model supports a given task type.

Parameters:

task_type (TaskType) – the task type to check

Returns:

True if and only if the task type is supported by the model

unfreeze_backbone() → None

Unfreezes previously frozen backbone parameters.