mml.core.scripts.schedulers.train_scheduler

class TrainingScheduler

New version of the former “optimization” scheduler. It supports the following features:
  • model training
  • model prediction
  • model testing

In addition to the standard hooks (after_preparation_hook, before_finishing_hook), it provides additional hooks that may be overridden by inheriting schedulers:
  • before_training_hook
  • after_training_hook

It further allows for task nesting and cross validation.
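The hook mechanism follows a template-method pattern: the base scheduler calls each hook at a fixed point, and subclasses override only the hooks they need. The sketch below is a self-contained stand-in, not the MML implementation. `ToyScheduler`, `LoggingScheduler`, `run_fold` and the exact call order are hypothetical names and assumptions made for illustration.

```python
from typing import Any, List


class ToyScheduler:
    """Hypothetical stand-in for TrainingScheduler that only models the hook order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    # Default hooks do nothing; inheriting schedulers override them.
    def before_training_hook(self, datamodule: Any, model: Any, trainer: Any,
                             fold: int, task_name: str) -> None:
        pass

    def after_training_hook(self, datamodule: Any, model: Any, trainer: Any,
                            fold: int, task_name: str) -> None:
        pass

    def run_fold(self, task_name: str, fold: int) -> None:
        # Placeholder objects stand in for the real datamodule, model and trainer.
        datamodule, model, trainer = object(), object(), object()
        # Assumed order: before-hook, then fitting, then after-hook.
        self.before_training_hook(datamodule, model, trainer, fold, task_name)
        self.calls.append(f"fit {task_name} fold {fold}")
        self.after_training_hook(datamodule, model, trainer, fold, task_name)


class LoggingScheduler(ToyScheduler):
    """Example subclass that records when each hook fires."""

    def before_training_hook(self, datamodule, model, trainer, fold, task_name):
        self.calls.append(f"before {task_name} fold {fold}")

    def after_training_hook(self, datamodule, model, trainer, fold, task_name):
        self.calls.append(f"after {task_name} fold {fold}")


scheduler = LoggingScheduler()
scheduler.run_fold("task_a", fold=0)
print(scheduler.calls)
# ['before task_a fold 0', 'fit task_a fold 0', 'after task_a fold 0']
```

The subclass does not touch the fitting logic itself. It only hooks into the points the base class exposes.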

__init__(cfg: DictConfig)
after_preparation_hook()
after_training_hook(datamodule: LightningDataModule, model: LightningModule, trainer: Trainer, fold: int, task_name: str) → None

This hook allows modifying the setup after model fitting (and any preceding Lightning tuning) has ended. It can be used to modify task_weights, data, trainer callbacks, etc. It may be overridden by schedulers that inherit from TrainingScheduler.

Parameters:
  • datamodule (lightning.LightningDataModule) – the datamodule used

  • model (lightning.LightningModule) – the trained model

  • trainer (lightning.Trainer) – the trainer used for fitting

  • fold (int) – the fold that was trained

  • task_name (str) – the pivot task

Returns:

None
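A typical use of after_training_hook is to inspect results once a fold has been fitted. The sketch below assumes the hook is overridden on a TrainingScheduler subclass. It is written as a standalone class with stand-in objects so that it runs without MML or Lightning installed. It reads Lightning's `Trainer.callback_metrics` dictionary. `FoldMetricCollector`, `fold_scores` and the metric name `"val_loss"` are hypothetical.

```python
from types import SimpleNamespace
from typing import Any, Dict, Tuple


class FoldMetricCollector:
    """Hypothetical; in real use this method would live on a TrainingScheduler subclass."""

    def __init__(self) -> None:
        self.fold_scores: Dict[Tuple[str, int], float] = {}

    def after_training_hook(self, datamodule: Any, model: Any, trainer: Any,
                            fold: int, task_name: str) -> None:
        # Lightning exposes logged metrics via `trainer.callback_metrics`;
        # "val_loss" is an assumed metric name.
        metric = trainer.callback_metrics.get("val_loss")
        if metric is not None:
            # float() also converts a scalar tensor, as Lightning stores them.
            self.fold_scores[(task_name, fold)] = float(metric)


# A SimpleNamespace stands in for the real trainer.
collector = FoldMetricCollector()
fake_trainer = SimpleNamespace(callback_metrics={"val_loss": 0.25})
collector.after_training_hook(None, None, fake_trainer, fold=1, task_name="task_a")
print(collector.fold_scores)  # {('task_a', 1): 0.25}
```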

before_finishing_hook()
before_training_hook(datamodule: LightningDataModule, model: LightningModule, trainer: Trainer, fold: int, task_name: str) → None

This hook allows modifying the setup before model fitting starts (and also before any Lightning tuning). It can be used to modify task_weights, data, trainer callbacks, etc. It may be overridden by schedulers that inherit from TrainingScheduler.

Parameters:
  • datamodule (lightning.LightningDataModule) – the prepared datamodule (no setup run yet)

  • model (lightning.LightningModule) – the prepared model

  • trainer (lightning.Trainer) – the prepared trainer

  • fold (int) – the current fold

  • task_name (str) – the current task

Returns:

None
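Because before_training_hook runs before fitting and tuning, it is a natural place to attach per-fold callbacks. The sketch below assumes that `trainer.callbacks` is the mutable list of callbacks a Lightning trainer holds. It uses a stand-in trainer so that the example runs on its own. `TimerCallback` and `CallbackInjector` are hypothetical. A real callback would subclass Lightning's `Callback`, and the method would be defined on a TrainingScheduler subclass.

```python
from types import SimpleNamespace
from typing import Any


class TimerCallback:
    """Hypothetical placeholder for a Lightning Callback subclass."""

    def __init__(self, tag: str) -> None:
        self.tag = tag


class CallbackInjector:
    """Hypothetical; in real use this method would live on a TrainingScheduler subclass."""

    def before_training_hook(self, datamodule: Any, model: Any, trainer: Any,
                             fold: int, task_name: str) -> None:
        # Called before fitting/tuning start; attach a callback tagged with task and fold.
        trainer.callbacks.append(TimerCallback(tag=f"{task_name}/fold{fold}"))


# A SimpleNamespace with an empty list stands in for the prepared trainer.
fake_trainer = SimpleNamespace(callbacks=[])
CallbackInjector().before_training_hook(None, None, fake_trainer, fold=0, task_name="task_a")
print([cb.tag for cb in fake_trainer.callbacks])  # ['task_a/fold0']
```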

create_routine()

This scheduler implements three sub-routines: training, testing and prediction. The routine takes care of cross validation and nesting.
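One way to picture the routine is as an ordered list of calls to the per-fold and per-task methods. The sketch below is hypothetical and is not how create_routine is implemented. `build_routine`, its parameters and the ordering it uses (per task, train and predict each fold, then test once) are assumptions. Only the method names and their argument order come from the signatures of train_fold, predict_fold and test_task. Task nesting is not modeled.

```python
from typing import List, Optional, Tuple

# (method name, positional arguments)
Command = Tuple[str, tuple]


def build_routine(tasks: List[str], n_folds: int, predict: bool = True,
                  test: bool = True, eval_on: Optional[str] = None) -> List[Command]:
    """Hypothetical sketch: expand tasks and folds into an ordered call list."""
    routine: List[Command] = []
    for task in tasks:
        # Cross validation: one training (and optional prediction) run per fold.
        for fold in range(n_folds):
            routine.append(("train_fold", (task, fold)))
            if predict:
                routine.append(("predict_fold", (task, fold, eval_on)))
        # Testing happens once per task, matching test_task(task_name, eval_on).
        if test:
            routine.append(("test_task", (task, eval_on)))
    return routine


for name, args in build_routine(["task_a"], n_folds=2):
    print(name, args)
```

A scheduler could then dispatch each entry with `getattr(self, name)(*args)`, which keeps the loop over tasks and folds separate from the work done per fold.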

predict_fold(task_name: str, fold: int, eval_on: str | None = None) → None
test_task(task_name: str, eval_on: str | None = None) → None
train_fold(task_name: str, fold: int) → None