callbacks

Callbacks for training.

Classes:

| Name | Description |
| --- | --- |
| EarlyStopping | Initialize the EarlyStopping callback. |
| Scheduler | Initialize a scheduler object to encapsulate different torch schedulers. |
| ModelCheckpoint | Model checkpoint initialization. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| Callback | | Type alias of the supported training callbacks. |

Callback = EarlyStopping | Scheduler | ModelCheckpoint #

EarlyStopping #

Initialize the EarlyStopping callback.

The given monitoring metric is evaluated on the validation set to monitor the training process. If no validation set is provided when calling the Trainer.train method, the monitoring metric is computed on the training set instead, which can lead to an overfitted model.

Note: at least one pruning step must have been performed for early stopping to be activated. If the model has not been pruned yet, training will not be stopped.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| monitoring_metric | str | Monitoring metric used for early stopping, for instance "Total loss", the sum of the mean explain loss and the mean loss. The metric is computed on the validation set by default. If no validation set is provided, the only available metric is "Total loss", computed on the train set. | required |
| mode | Literal['maximize', 'minimize'] | Whether to "maximize" or "minimize" the provided metric. | required |
| patience | int | Number of epochs to wait with no improvement of the monitored value. | 3 |
| min_delta | float | Minimum delta between two monitored values to count as an improvement. | 0.0 |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| type_ | str | |

Methods:

| Name | Description |
| --- | --- |
| to_model | To callback model. |

type_: str = field(init=False, default='EARLY_STOPPING') #

mode: Literal['maximize', 'minimize'] = field(validator=(lambda self, attribute, value: value in [(str(mode_)) for mode_ in EarlyStoppingTrainCallbackTrainCallbackMode]), converter=EarlyStoppingTrainCallbackTrainCallbackMode) #

monitoring_metric: str #

patience: int = 3 #

min_delta: float = 0.0 #

to_model() -> EarlyStoppingTrainCallback #

To callback model.

Source code in src/xpdeep/trainer/callbacks.py
def to_model(self) -> EarlyStoppingTrainCallback:
    """To callback model."""
    return EarlyStoppingTrainCallback(**_attrs_asdict(self))
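
For illustration, here is a minimal sketch of configuring early stopping; the metric name, patience, and delta values are illustrative, and the callback is assumed to be handed to the trainer alongside any others.

```python
from xpdeep.trainer.callbacks import EarlyStopping

# Stop training once "Total loss" has not improved for 5 consecutive epochs.
# Remember that early stopping only activates after at least one pruning step.
early_stopping = EarlyStopping(
    monitoring_metric="Total loss",
    mode="minimize",
    patience=5,
    min_delta=0.01,
)

# The callback can be serialized to its backend representation explicitly.
early_stopping_model = early_stopping.to_model()
```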

Scheduler #

Initialize a scheduler object to encapsulate different torch schedulers.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| pre_scheduler | partial[LRScheduler] | Base torch LR scheduler to be instantiated. It must not contain the optimizer, as xpdeep uses the trainer's optimizer for the scheduler internally. | required |
| step_method | Literal['batch', 'epoch'] | Whether the scheduler steps every "batch" or every "epoch". | required |
| monitoring_metric | str | Monitoring metric used for the step method, for instance "Total loss", the sum of the mean explain loss and the mean loss. The metric is computed on the validation set by default. If no validation set is provided, the only available metric is "Total loss", computed on the train set. | required |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| type_ | str | |

Methods:

| Name | Description |
| --- | --- |
| to_model | As SchedulerRequestBody. |

type_: str = field(init=False, default='SCHEDULER') #

pre_scheduler: partial[LRScheduler | ReduceLROnPlateau] #

step_method: Literal['batch', 'epoch'] = field(validator=(lambda self, attribute, value: value in [(str(step_method_)) for step_method_ in SchedulerTrainCallbackInputSchedulerTrainCallbackStepMethod]), converter=SchedulerTrainCallbackInputSchedulerTrainCallbackStepMethod) #

monitoring_metric: str #

to_model() -> SchedulerTrainCallbackInput #

As SchedulerRequestBody.

Source code in src/xpdeep/trainer/callbacks.py
def to_model(self) -> SchedulerTrainCallbackInput:
    """As SchedulerRequestBody."""
    self_as_dict = _attrs_asdict(self)

    self_as_dict["pre_scheduler"] = TrustedObjectInput.from_dict({
        "reconstructor": "partial",
        "class_": self.pre_scheduler.func.__name__,
        "module": self.pre_scheduler.func.__module__,
        "state": {
            "args": self.pre_scheduler.args,
            "kwargs": self.pre_scheduler.keywords,
        },
    })
    self_as_dict["step_method"] = SchedulerTrainCallbackInputSchedulerTrainCallbackStepMethod(self.step_method)

    return SchedulerTrainCallbackInput(**self_as_dict)
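
For illustration, the sketch below wraps a torch ReduceLROnPlateau in a Scheduler callback. The optimizer is deliberately omitted from the partial, since xpdeep binds the trainer's optimizer to it internally; the scheduler hyperparameters are illustrative.

```python
from functools import partial

from torch.optim.lr_scheduler import ReduceLROnPlateau

from xpdeep.trainer.callbacks import Scheduler

# Pre-configure the torch scheduler without an optimizer: xpdeep attaches the
# trainer's optimizer to it internally.
scheduler = Scheduler(
    pre_scheduler=partial(ReduceLROnPlateau, factor=0.5, patience=2),
    step_method="epoch",
    monitoring_metric="Total loss",
)
```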

ModelCheckpoint #

Model checkpoint initialization.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| monitoring_metric | str | Monitoring metric used to track the best model, for instance "Total loss", the sum of the mean explain loss and the mean loss. The metric is computed on the validation set by default. If no validation set is provided, the only available metric is "Total loss", computed on the train set. | required |
| save_every_epoch | int | How often (in epochs) to save the model. If None, only the best checkpoint is saved. | 1 |
| mode | Literal['maximize', 'minimize'] | Whether to "maximize" or "minimize" the provided metric. | required |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| type_ | str | |

Methods:

| Name | Description |
| --- | --- |
| to_model | Parse the object to a RequestBody for the backend. |

type_: str = field(init=False, default='MODEL_CHECKPOINT') #

monitoring_metric: str #

mode: Literal['maximize', 'minimize'] = field(validator=(lambda self, attribute, value: value in [(str(mode_)) for mode_ in ModelCheckpointTrainCallbackTrainCallbackMode]), converter=ModelCheckpointTrainCallbackTrainCallbackMode) #

save_every_epoch: int = 1 #

to_model() -> ModelCheckpointTrainCallback #

Parse the object to a RequestBody for the backend.

Source code in src/xpdeep/trainer/callbacks.py
def to_model(self) -> ModelCheckpointTrainCallback:
    """Parse the object to a RequestBody for the backend."""
    self_as_dict = _attrs_asdict(self)
    self_as_dict["mode"] = ModelCheckpointTrainCallbackTrainCallbackMode(self.mode)

    return ModelCheckpointTrainCallback(**self_as_dict)
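
For illustration, a minimal sketch of a checkpoint callback, combined with an early-stopping callback into a list annotated with the Callback type alias defined at the top of this module; how that list is ultimately passed to the trainer is not shown here.

```python
from xpdeep.trainer.callbacks import Callback, EarlyStopping, ModelCheckpoint

# Track the best model according to "Total loss" and save a checkpoint every epoch.
checkpoint = ModelCheckpoint(
    monitoring_metric="Total loss",
    mode="minimize",
    save_every_epoch=1,
)

# Any of the three callback classes satisfies the Callback union type alias.
callbacks: list[Callback] = [
    EarlyStopping(monitoring_metric="Total loss", mode="minimize"),
    checkpoint,
]
```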