mml.core.models.torch_base
- class BaseHead
- class BaseModel
- __init__(**kwargs)
- The base class for MML models. Derived classes must implement the following methods:
_init_model() – for backbone initialization
_create_head() – for head creation
supports() – for reporting the supported task types
forward() – the model's forward pass through the backbone and all heads
forward_features() – alternative usage as a feature extractor
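As a rough illustration of this contract, the sketch below uses a hypothetical stand-in base class (SketchBaseModel) built on torch.nn.Module and abc, plus a hypothetical subclass (TinyMLP). The hook signatures, the num_classes argument and the "classification" task-type string are assumptions, not the actual mml signatures.

```python
from abc import ABC, abstractmethod
from typing import Dict

import torch
from torch import nn


class SketchBaseModel(nn.Module, ABC):
    """Hypothetical stand-in for BaseModel: subclasses must fill in every abstract hook."""

    def __init__(self, **kwargs):
        super().__init__()
        self.heads = nn.ModuleDict()                # one entry per task head
        self.backbone = self._init_model(**kwargs)  # subclass builds the backbone

    @abstractmethod
    def _init_model(self, **kwargs) -> nn.Module: ...

    @abstractmethod
    def _create_head(self, num_classes: int) -> nn.Module: ...

    @abstractmethod
    def supports(self, task_type: str) -> bool: ...

    @abstractmethod
    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: ...

    @abstractmethod
    def forward_features(self, x: torch.Tensor) -> torch.Tensor: ...


class TinyMLP(SketchBaseModel):
    """Hypothetical concrete model implementing all five hooks."""

    def _init_model(self, in_features: int = 8, hidden: int = 16, **kwargs) -> nn.Module:
        return nn.Sequential(nn.Linear(in_features, hidden), nn.ReLU())

    def _create_head(self, num_classes: int) -> nn.Module:
        # Heads consume the backbone's output width.
        return nn.Linear(self.backbone[0].out_features, num_classes)

    def supports(self, task_type: str) -> bool:
        return task_type == "classification"

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        features = self.backbone(x)  # backbone runs once, shared by all heads
        return {name: head(features) for name, head in self.heads.items()}

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        return torch.flatten(self.backbone(x), start_dim=1)


model = TinyMLP(in_features=8, hidden=16)
model.heads["task_a"] = model._create_head(num_classes=3)  # add_head would normally do this
outputs = model(torch.randn(2, 8))
```

Instantiating SketchBaseModel directly raises a TypeError, which is how abc enforces that every hook is implemented.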
- add_head(task_struct: TaskStruct, **kwargs: Any) -> None
Adds a head for the given task to the model.
- Parameters:
task_struct (TaskStruct) – struct for the task to add a head for
kwargs (Any) – additional kwargs that will be forwarded
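A minimal sketch of what adding a head could look like, assuming heads are stored in an nn.ModuleDict keyed by task name. FakeTaskStruct and its fields (name, task_type, num_classes) are hypothetical stand-ins for TaskStruct, and forwarding the extra kwargs to _create_head is an assumption about where they go.

```python
from dataclasses import dataclass
from typing import Any

from torch import nn


@dataclass
class FakeTaskStruct:
    """Hypothetical stand-in for TaskStruct; the real fields may differ."""
    name: str
    task_type: str
    num_classes: int


class HeadRegistrySketch(nn.Module):
    """Hypothetical model showing only the head-registration logic."""

    def __init__(self, feature_dim: int = 16):
        super().__init__()
        self.backbone = nn.Linear(8, feature_dim)
        self.heads = nn.ModuleDict()

    def supports(self, task_type: str) -> bool:
        return task_type == "classification"

    def _create_head(self, task_struct: FakeTaskStruct, **kwargs: Any) -> nn.Module:
        return nn.Linear(self.backbone.out_features, task_struct.num_classes)

    def add_head(self, task_struct: FakeTaskStruct, **kwargs: Any) -> None:
        # Refuse task types the model does not report as supported.
        if not self.supports(task_struct.task_type):
            raise ValueError(f"unsupported task type: {task_struct.task_type}")
        # Assumption: extra kwargs go to head creation; register the head under the task name.
        self.heads[task_struct.name] = self._create_head(task_struct, **kwargs)


model = HeadRegistrySketch()
model.add_head(FakeTaskStruct(name="cifar10", task_type="classification", num_classes=10))
```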
- count_parameters(only_trainable: bool = True) -> Dict[str, int]
Counts the model's parameters, broken down by component.
- Parameters:
only_trainable (bool) – if True, only counts parameters with requires_grad=True.
- Returns:
a dict mapping each component name (the backbone or a head's name) to its parameter count
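The per-component counting can be sketched as a standalone function. It assumes the model exposes a backbone attribute and a heads nn.ModuleDict; those attribute names are assumptions for illustration, not confirmed parts of BaseModel.

```python
from typing import Dict

from torch import nn


def count_parameters(model: nn.Module, only_trainable: bool = True) -> Dict[str, int]:
    """Per-component parameter counts: the backbone plus one entry per head."""

    def _count(module: nn.Module) -> int:
        # Skip frozen parameters unless all parameters were requested.
        return sum(
            p.numel() for p in module.parameters()
            if p.requires_grad or not only_trainable
        )

    counts = {"backbone": _count(model.backbone)}
    for name, head in model.heads.items():
        counts[name] = _count(head)
    return counts


model = nn.Module()
model.backbone = nn.Linear(4, 3)                           # 4*3 weights + 3 biases = 15
model.heads = nn.ModuleDict({"task_a": nn.Linear(3, 2)})   # 3*2 weights + 2 biases = 8
model.backbone.weight.requires_grad_(False)                # freeze the 12 backbone weights
print(count_parameters(model))                        # {'backbone': 3, 'task_a': 8}
print(count_parameters(model, only_trainable=False))  # {'backbone': 15, 'task_a': 8}
```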
- abstract forward(x: Tensor) -> Dict[str, Tensor]
Model forward pass. Passes the input through the backbone once and forwards the output through each head.
- Parameters:
x (torch.Tensor) – input tensor
- Returns:
a dictionary with one entry per head; the key is the head name and the value is the head's output
- abstract forward_features(x: Tensor) -> Tensor
Feature extraction. Only forwards the input through the backbone and post-processes the features to 1D.
- Parameters:
x (torch.Tensor) – input tensor
- Returns:
a 1D tensor
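The relationship between forward and forward_features can be sketched with a hypothetical two-head model (ConvSketch). Reusing forward_features inside forward, flattening per sample with torch.flatten(start_dim=1), and the layer sizes are all assumptions made for illustration.

```python
from typing import Dict

import torch
from torch import nn


class ConvSketch(nn.Module):
    """Hypothetical model: one convolutional backbone shared by two heads."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(
            nn.Conv2d(3, 8, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # output shape (B, 8, 1, 1)
        )
        self.heads = nn.ModuleDict({
            "task_a": nn.Linear(8, 5),
            "task_b": nn.Linear(8, 2),
        })

    def forward_features(self, x: torch.Tensor) -> torch.Tensor:
        # Backbone only, then flatten each sample's features to 1D: (B, 8).
        return torch.flatten(self.backbone(x), start_dim=1)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        features = self.forward_features(x)  # the backbone runs exactly once
        return {name: head(features) for name, head in self.heads.items()}


model = ConvSketch()
x = torch.randn(4, 3, 32, 32)
outputs = model(x)
print({name: tuple(t.shape) for name, t in outputs.items()})  # {'task_a': (4, 5), 'task_b': (4, 2)}
print(tuple(model.forward_features(x).shape))                 # (4, 8)
```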
- static load_checkpoint(param_path: Path | str) -> BaseModel
Load a model from a checkpoint.
- Parameters:
param_path (Union[Path, str]) – path to load the checkpoint from
- Returns:
the model loaded from the checkpoint
- save_checkpoint(param_path: Path | str) -> None
Save a model checkpoint.
- Parameters:
param_path (Union[Path, str]) – path to store the checkpoint at
- Returns:
None
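A save/load round trip could look roughly like the sketch below. The on-disk layout (constructor kwargs stored next to the state_dict) and the CheckpointSketch class are assumptions, not the documented mml checkpoint format; storing the kwargs is one way a static loader can rebuild the model without an existing instance.

```python
import tempfile
from pathlib import Path
from typing import Union

import torch
from torch import nn


class CheckpointSketch(nn.Module):
    """Hypothetical model showing a save/load round trip; not the mml on-disk format."""

    def __init__(self, in_features: int = 4, num_classes: int = 3):
        super().__init__()
        self.init_kwargs = {"in_features": in_features, "num_classes": num_classes}
        self.net = nn.Linear(in_features, num_classes)

    def save_checkpoint(self, param_path: Union[Path, str]) -> None:
        # Store constructor arguments next to the weights so loading can rebuild the model.
        torch.save({"kwargs": self.init_kwargs, "state_dict": self.state_dict()}, param_path)

    @staticmethod
    def load_checkpoint(param_path: Union[Path, str]) -> "CheckpointSketch":
        checkpoint = torch.load(param_path, map_location="cpu")
        model = CheckpointSketch(**checkpoint["kwargs"])  # rebuild the architecture
        model.load_state_dict(checkpoint["state_dict"])   # then restore the weights
        return model


with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model.pt"
    original = CheckpointSketch()
    original.save_checkpoint(path)
    restored = CheckpointSketch.load_checkpoint(path)
    print(torch.equal(original.net.weight, restored.net.weight))  # True
```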