mml.core.models.torch_base
- class BaseHead[source]
- class BaseModel[source]
- __init__(**kwargs)[source]
- The base class for MML models. Derived classes must implement the following methods:
_init_model() - for backbone initialization
_create_head() - for head creation
supports() - reporting supported task types
forward() - the model's forward pass through the backbone and all heads
forward_features() - alternative usage as feature extractor
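The contract above can be illustrated with a small, framework-free sketch. The class names `SketchBaseModel` and `ToyModel` are hypothetical, and the argument lists of `_create_head()` and `supports()` are assumptions, since only the method names are documented here. The real `BaseModel` works on torch modules and tensors, not on plain Python lists.

```python
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List


class SketchBaseModel(ABC):
    """Hypothetical stand-in for BaseModel, NOT the MML class itself."""

    def __init__(self, **kwargs: Any) -> None:
        self.kwargs = kwargs
        self.heads: Dict[str, Callable] = {}
        self.backbone = self._init_model()

    @abstractmethod
    def _init_model(self) -> Callable:
        """Backbone initialization."""

    @abstractmethod
    def _create_head(self, task_name: str) -> Callable:
        """Head creation (signature is an assumption)."""

    @abstractmethod
    def supports(self, task_type: str) -> bool:
        """Report supported task types (signature is an assumption)."""

    @abstractmethod
    def forward(self, x: List[float]) -> Dict[str, Any]:
        """Pass input through the backbone and all heads."""

    @abstractmethod
    def forward_features(self, x: List[float]) -> List[float]:
        """Use the backbone alone as a feature extractor."""


class ToyModel(SketchBaseModel):
    def _init_model(self) -> Callable:
        # Toy backbone: doubles every input value.
        return lambda x: [2.0 * v for v in x]

    def _create_head(self, task_name: str) -> Callable:
        # Toy head: sums the features.
        return lambda feats: sum(feats)

    def supports(self, task_type: str) -> bool:
        return task_type == "classification"

    def forward(self, x: List[float]) -> Dict[str, Any]:
        feats = self.backbone(x)  # backbone runs once
        return {name: head(feats) for name, head in self.heads.items()}

    def forward_features(self, x: List[float]) -> List[float]:
        return self.backbone(x)


model = ToyModel()
model.heads["task_a"] = model._create_head("task_a")
outputs = model.forward([1.0, 2.0])
```

Because all five methods are declared abstract, instantiating the base class directly raises a `TypeError`. A subclass can only be created once every method is implemented.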
- add_head(task_struct: TaskStruct, **kwargs: Any) None[source]
The functionality used to add heads to a model.
- Parameters:
task_struct (TaskStruct) – struct for the task to add a head for
kwargs (Any) – additional kwargs that will be forwarded
- count_parameters(only_trainable: bool = True) Dict[str, int][source]
Gives information on parameter count of the model.
- Parameters:
only_trainable (bool) – if True, only counts parameters that require gradients (requires_grad is True).
- Returns:
a dict with component names as key (backbone or name of heads) and parameter count as value
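A per-component count like the one described can be sketched in plain PyTorch. The helper `count_parameters_sketch` and the component names are hypothetical. The sketch only shows one plausible way to compute such a dict, not how MML does it.

```python
from typing import Dict

from torch import nn


def count_parameters_sketch(
    components: Dict[str, nn.Module], only_trainable: bool = True
) -> Dict[str, int]:
    """Count parameters per named component (hypothetical helper)."""
    return {
        name: sum(
            p.numel()
            for p in module.parameters()
            if p.requires_grad or not only_trainable
        )
        for name, module in components.items()
    }


backbone = nn.Linear(4, 8)  # 4 * 8 weights + 8 biases = 40 parameters
head = nn.Linear(8, 2)      # 8 * 2 weights + 2 biases = 18 parameters

# Freeze the backbone so it no longer counts as trainable.
for p in backbone.parameters():
    p.requires_grad = False

components = {"backbone": backbone, "task_a": head}
trainable = count_parameters_sketch(components)
total = count_parameters_sketch(components, only_trainable=False)
```

With the backbone frozen, `trainable` reports zero parameters for `backbone`, while `total` still includes all of them.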
- abstract forward(x: Tensor) Dict[str, Tensor][source]
Model forward functionality. Passes input through the backbone once and forwards output through each head.
- Parameters:
x (torch.Tensor) – input tensor
- Returns:
a dictionary with one entry per head; the key is the head name and the value is the head output
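The shared-backbone pattern described for forward() might look roughly like the following in plain PyTorch. `MultiHeadSketch` and `add_head_sketch` are hypothetical names. The real add_head() takes a TaskStruct, which is replaced here by a plain name and output size for illustration.

```python
from typing import Dict

import torch
from torch import nn


class MultiHeadSketch(nn.Module):
    """Hypothetical shared-backbone model, not the MML implementation."""

    def __init__(self, in_dim: int = 16, feat_dim: int = 8) -> None:
        super().__init__()
        self.feat_dim = feat_dim
        self.backbone = nn.Sequential(nn.Linear(in_dim, feat_dim), nn.ReLU())
        self.heads = nn.ModuleDict()

    def add_head_sketch(self, name: str, num_outputs: int) -> None:
        # Simplified stand-in for add_head(task_struct, **kwargs).
        self.heads[name] = nn.Linear(self.feat_dim, num_outputs)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        feats = self.backbone(x)  # the backbone is evaluated only once
        return {name: head(feats) for name, head in self.heads.items()}


model = MultiHeadSketch()
model.add_head_sketch("task_a", num_outputs=3)
model.add_head_sketch("task_b", num_outputs=5)
out = model(torch.zeros(4, 16))
```

Here `out` holds one tensor per head, keyed by head name, and every head receives the same backbone features.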
- abstract forward_features(x: Tensor) Tensor[source]
Feature extraction functionality. Only forwards features through the backbone and post-processes them to 1D.
- Parameters:
x (torch.Tensor) – input tensor
- Returns:
a 1D tensor
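A minimal sketch of backbone-only feature extraction follows. It rests on two assumptions not stated in this reference: that "1D" means one feature vector per sample, and that the post-processing is global average pooling followed by flattening. The convolutional backbone and the function name `forward_features_sketch` are hypothetical.

```python
import torch
from torch import nn

# Toy convolutional backbone (assumption for illustration only).
backbone = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU())


def forward_features_sketch(x: torch.Tensor) -> torch.Tensor:
    feature_maps = backbone(x)  # shape (N, 8, H, W)
    # Global average pooling collapses the spatial dimensions to 1 x 1.
    pooled = nn.functional.adaptive_avg_pool2d(feature_maps, 1)
    # Flatten everything except the batch dimension: shape (N, 8).
    return torch.flatten(pooled, start_dim=1)


features = forward_features_sketch(torch.zeros(2, 3, 32, 32))
```

No head is involved, so the result depends only on the backbone weights.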
- static get_lora_compatible_layers(backbone: Module) List[str][source]
Helper function to extract all LoRA-compatible layers (from the peft library).
- Parameters:
backbone (torch.nn.Module) – the model to extract layers from
- Returns:
list of strings that correspond to the respective layer names
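One way such a helper could work is to walk the module tree and keep the names of layers whose type LoRA can wrap. The sketch assumes that linear, 2D convolution and embedding layers are the relevant types. The set MML actually checks against may differ, and `lora_compatible_layers_sketch` is a hypothetical name.

```python
from typing import List

from torch import nn

# Assumed set of LoRA-wrappable layer types (not taken from MML).
LORA_LAYER_TYPES = (nn.Linear, nn.Conv2d, nn.Embedding)


def lora_compatible_layers_sketch(backbone: nn.Module) -> List[str]:
    """Return the qualified names of layers matching the assumed types."""
    return [
        name
        for name, module in backbone.named_modules()
        if isinstance(module, LORA_LAYER_TYPES)
    ]


backbone = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.Flatten(),
    nn.Linear(8, 4),
)
names = lora_compatible_layers_sketch(backbone)
```

The returned names follow the `named_modules()` convention, so they should be usable wherever layer names are expected as adapter targets.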
- static load_checkpoint(param_path: Path | str) BaseModel[source]
Load from a checkpoint. Be aware that MML uses its own checkpoint structure (different from the one in lightning). Details can be found in save_checkpoint().
- Parameters:
param_path (Union[Path, str]) – path to load checkpoint from
- Returns:
- save_checkpoint(param_path: Path | str) None[source]
Save a model checkpoint.
- Parameters:
param_path (Union[Path, str]) – path to store checkpoint
- Returns:
- set_peft(peft_cfg: PeftConfig) None[source]
Applies a PEFT (Parameter-Efficient Fine-Tuning) method to the model. Usually this leads to adapters being injected into the base model that complement the existing weights. The advantage is that the majority of the existing weights are frozen (the .requires_grad attribute of the tensors is set to False) while only the smaller adapters are kept trainable.
- Parameters:
peft_cfg (PeftConfig) – PEFTConfig instance, see huggingface/peft
- Returns:
None, since the model is modified in place
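The freezing effect described for set_peft() can be illustrated with a hand-written low-rank adapter in plain PyTorch. This is a simplified LoRA-style sketch, not the peft library's implementation and not what set_peft() executes. The class name `LoRALinearSketch` and the chosen rank are hypothetical.

```python
import torch
from torch import nn


class LoRALinearSketch(nn.Module):
    """Frozen linear layer plus a small trainable low-rank adapter."""

    def __init__(self, base: nn.Linear, rank: int = 2) -> None:
        super().__init__()
        self.base = base
        # Freeze the existing weights, as a PEFT method would.
        for p in self.base.parameters():
            p.requires_grad = False
        self.lora_a = nn.Linear(base.in_features, rank, bias=False)
        self.lora_b = nn.Linear(rank, base.out_features, bias=False)
        # Zero-initialise the second factor so the adapter starts as a no-op.
        nn.init.zeros_(self.lora_b.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base(x) + self.lora_b(self.lora_a(x))


layer = LoRALinearSketch(nn.Linear(16, 16), rank=2)
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in layer.parameters() if not p.requires_grad)
x = torch.randn(3, 16)
same_at_start = torch.allclose(layer(x), layer.base(x))
```

In this toy case the adapter adds 64 trainable parameters next to 272 frozen ones, and the output is unchanged until the adapter weights are trained.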