mml.core.scripts.schedulers.train_scheduler

class TrainingScheduler

New version of the former “optimization” scheduler. Supports the following features:

  • model training

  • model prediction

  • model testing

In addition to the standard hooks (after_preparation_hook, before_finishing_hook), it provides additional hooks that may be overridden by inheriting schedulers:

  • before_training_hook

  • after_training_hook

It further allows for task nesting and cross validation.
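
The training hooks follow a template-method pattern: the scheduler drives fitting and calls overridable no-op methods at fixed points. The sketch below illustrates the call order described for the hooks (before_training_hook before any tuning and fitting, after_training_hook after both). It is a minimal stand-alone illustration under that assumption: SketchScheduler, RecordingScheduler, the tune flag and the placeholder objects are hypothetical, the body of train_fold is not the actual mml implementation, and Lightning is not required.

```python
from typing import Any, List


class SketchScheduler:
    """Hypothetical stand-in showing where the training hooks fire; not the mml implementation."""

    def __init__(self, tune: bool = False) -> None:
        self.tune = tune  # hypothetical flag: whether a tuning step runs before fitting
        self.calls: List[str] = []  # records the order in which steps run

    # Default hooks do nothing; inheriting schedulers override them.
    def before_training_hook(self, datamodule: Any, model: Any, trainer: Any, fold: int, task_name: str) -> None:
        pass

    def after_training_hook(self, datamodule: Any, model: Any, trainer: Any, fold: int, task_name: str) -> None:
        pass

    def train_fold(self, task_name: str, fold: int) -> None:
        # Placeholders for the prepared Lightning datamodule, model and trainer.
        datamodule, model, trainer = object(), object(), object()
        self.before_training_hook(datamodule, model, trainer, fold, task_name)
        if self.tune:
            self.calls.append("tune")  # optional Lightning tuning would run here
        self.calls.append("fit")  # trainer.fit(...) would run here
        self.after_training_hook(datamodule, model, trainer, fold, task_name)


class RecordingScheduler(SketchScheduler):
    """Example subclass that overrides both hooks to record when they run."""

    def before_training_hook(self, datamodule, model, trainer, fold, task_name):
        self.calls.append(f"before:{task_name}:{fold}")

    def after_training_hook(self, datamodule, model, trainer, fold, task_name):
        self.calls.append(f"after:{task_name}:{fold}")


scheduler = RecordingScheduler(tune=True)
scheduler.train_fold("task_a", 0)
print(scheduler.calls)  # ['before:task_a:0', 'tune', 'fit', 'after:task_a:0']
```

Because the base hooks are no-ops, a subclass only overrides the hooks it needs, and the driving loop stays unchanged.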

__init__(cfg: DictConfig)
after_preparation_hook()
after_training_hook(datamodule: LightningDataModule, model: LightningModule, trainer: Trainer, fold: int, task_name: str) → None

This hook allows modifying the setup after model fitting has ended (and after any Lightning tuning). It can be used to modify weights, data, trainer callbacks, etc. May be overridden by schedulers inheriting from TrainingScheduler.

Parameters:
  • datamodule (lightning.LightningDataModule) – the datamodule used

  • model (lightning.LightningModule) – the trained model

  • trainer (lightning.Trainer) – the trainer used

  • fold (int) – the fold used

  • task_name (str) – the pivot task

Returns:

None
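
One possible use of after_training_hook is collecting per-fold results for later aggregation across cross-validation folds. The sketch below is a hypothetical example: MetricsRecordingScheduler is written without the TrainingScheduler base class so it runs standalone (in practice it would inherit from TrainingScheduler and call super().__init__(cfg)), and the trainer is replaced by a SimpleNamespace stand-in. It relies on Lightning's Trainer exposing logged values through callback_metrics; the "val_loss" key is an assumption about what the model logs.

```python
from types import SimpleNamespace
from typing import Any, Dict, Tuple


class MetricsRecordingScheduler:
    """Hypothetical after_training_hook override, shown without the base class."""

    def __init__(self) -> None:
        # (task_name, fold) -> snapshot of logged metrics taken right after fitting
        self.fold_metrics: Dict[Tuple[str, int], Dict[str, float]] = {}

    def after_training_hook(self, datamodule: Any, model: Any, trainer: Any, fold: int, task_name: str) -> None:
        # In Lightning the values are scalar tensors; float() converts both tensors and plain numbers.
        self.fold_metrics[(task_name, fold)] = {k: float(v) for k, v in trainer.callback_metrics.items()}


recorder = MetricsRecordingScheduler()
for fold, loss in enumerate([0.40, 0.50]):
    fake_trainer = SimpleNamespace(callback_metrics={"val_loss": loss})  # stand-in for lightning.Trainer
    recorder.after_training_hook(None, None, fake_trainer, fold, "task_a")

mean_loss = sum(m["val_loss"] for m in recorder.fold_metrics.values()) / len(recorder.fold_metrics)
print(round(mean_loss, 2))  # 0.45
```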

before_finishing_hook()
before_training_hook(datamodule: LightningDataModule, model: LightningModule, trainer: Trainer, fold: int, task_name: str) → None

This hook allows modifying the setup before model fitting starts (and also before Lightning tuning). It can be used to modify weights, data, trainer callbacks, etc. May be overridden by schedulers inheriting from TrainingScheduler.

Parameters:
  • datamodule (lightning.LightningDataModule) – the prepared datamodule (no setup run yet)

  • model (lightning.LightningModule) – the prepared model

  • trainer (lightning.Trainer) – the prepared trainer

  • fold (int) – the current fold

  • task_name (str) – the current task

Returns:

None
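
A common reason to modify weights before fitting is freezing part of the model for selected tasks. The sketch below is hypothetical: it assumes the model exposes a backbone submodule, FreezeBackboneScheduler and frozen_tasks are illustrative names, and StubModule stands in for a torch.nn.Module (mirroring the name and return-self behaviour of requires_grad_) so the example runs without PyTorch or mml. In practice the override would live in a subclass of TrainingScheduler.

```python
from types import SimpleNamespace
from typing import Any, Set


class StubModule:
    """Stand-in for a torch.nn.Module; only tracks whether its weights are trainable."""

    def __init__(self) -> None:
        self.trainable = True

    def requires_grad_(self, requires_grad: bool = True) -> "StubModule":
        self.trainable = requires_grad
        return self


class FreezeBackboneScheduler:
    """Hypothetical before_training_hook override, shown without the base class."""

    def __init__(self, frozen_tasks: Set[str]) -> None:
        self.frozen_tasks = frozen_tasks  # tasks for which the backbone is not trained

    def before_training_hook(self, datamodule: Any, model: Any, trainer: Any, fold: int, task_name: str) -> None:
        # Freeze the assumed feature extractor only for selected tasks; the head stays trainable.
        if task_name in self.frozen_tasks:
            model.backbone.requires_grad_(False)


model = SimpleNamespace(backbone=StubModule(), head=StubModule())
FreezeBackboneScheduler({"task_a"}).before_training_hook(None, model, None, 0, "task_a")
print(model.backbone.trainable, model.head.trainable)  # False True
```

Running this logic in before_training_hook rather than after_training_hook matters: the change must be in place before any tuning or fitting reads the model's parameters.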

create_routine()

This scheduler implements three sub-routines: training, testing and prediction. The routine takes care of cross-validation and task nesting.
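
The routine can be pictured as an expansion of tasks and folds into calls to the sub-routine methods train_fold, predict_fold and test_task. The function below is a hypothetical sketch of one plausible ordering: sketch_routine, n_folds and the train/predict/test flags are assumptions (the real method reads its settings from the cfg under names not documented here), and task nesting is not modelled.

```python
from typing import Any, Dict, List, Tuple

# (method name, keyword arguments) pairs matching the documented sub-routine signatures
Command = Tuple[str, Dict[str, Any]]


def sketch_routine(tasks: List[str], n_folds: int, train: bool = True,
                   predict: bool = False, test: bool = False) -> List[Command]:
    """Hypothetical expansion of tasks x folds into sub-routine calls."""
    commands: List[Command] = []
    for task_name in tasks:
        for fold in range(n_folds):
            if train:
                commands.append(("train_fold", {"task_name": task_name, "fold": fold}))
            if predict:
                commands.append(("predict_fold", {"task_name": task_name, "fold": fold}))
        if test:
            # Testing runs once per task, after all of its folds.
            commands.append(("test_task", {"task_name": task_name}))
    return commands


for name, kwargs in sketch_routine(["task_a"], n_folds=2, test=True):
    print(name, kwargs)
```

A scheduler could then dispatch each entry with getattr(self, name)(**kwargs); keeping the routine as plain data makes the order of work easy to inspect before anything is trained.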

predict_fold(task_name: str, fold: int, eval_on: str | None = None) → None
test_task(task_name: str, eval_on: str | None = None) → None
train_fold(task_name: str, fold: int) → None