trainer
Implement the trainer interface.
Modules:
Name | Description |
---|---|
jobs | Jobs utilities. |
Classes:
Name | Description |
---|---|
CancelMessage | Cancel message for the training generator. |
Trainer | Interface to train a model. |
Attributes:
Name | Type | Description |
---|---|---|
T | | |
T = TypeVar('T')
CancelMessage
Cancel message for the training generator.
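The reference only says that `CancelMessage` is a cancel message for a training generator. As a hedged illustration of that general pattern (not xpdeep's actual implementation), the sketch below assumes a generator that yields one value per step and stops early when the caller sends it a `CancelMessage` instance through `generator.send()`. The local `CancelMessage` class and the `training_steps` function are hypothetical stand-ins, and `T` plays the same role as the module's `T = TypeVar('T')`.

```python
from __future__ import annotations

from collections.abc import Generator
from typing import TypeVar

T = TypeVar("T")


class CancelMessage:
    """Hypothetical stand-in for xpdeep's CancelMessage sentinel."""


def training_steps(items: list[T]) -> Generator[T, CancelMessage | None, int]:
    """Yield one item per step and stop early if a CancelMessage is sent.

    The number of completed steps is returned, i.e. exposed as
    StopIteration.value when the generator finishes.
    """
    completed = 0
    for item in items:
        # `received` is what the caller passed to send(); next() sends None.
        received = yield item
        completed += 1
        if isinstance(received, CancelMessage):
            break
    return completed


gen = training_steps(["epoch-0", "epoch-1", "epoch-2"])
print(next(gen))  # epoch-0
try:
    gen.send(CancelMessage())  # ask the generator to stop
except StopIteration as stop:
    print(stop.value)  # 1
```

Using a dedicated message class rather than a bare flag such as `True` makes the cancel request explicit and leaves `send()` free to carry other values later.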
Trainer(loss: _Loss, optimizer: partial[Optimizer], *, max_epochs: int, start_epoch: int | None = None, callbacks: list[Callback] | None = None, metrics: DictMetrics | None = None)
Interface to train a model.
Initialize the trainer.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
loss | _Loss | The loss to use during training. | required |
optimizer | partial[Optimizer] | A partially applied torch optimizer (functools.partial). | required |
max_epochs | int | Maximum number of training epochs. | required |
start_epoch | int \| None | The epoch to start from when training is resumed. | None |
callbacks | list[Callback] \| None | Optional list of callbacks to use during training. | None |
metrics | DictMetrics \| None | Optional dictionary of metrics to use during training. | None |
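The `optimizer` argument is typed `partial[Optimizer]`: a `functools.partial` that fixes the optimizer hyperparameters but not the parameters to optimize. With PyTorch this would typically be written `partial(torch.optim.SGD, lr=0.01, momentum=0.9)`. The sketch below assumes (the reference does not confirm it) that the trainer completes the partial by passing the model parameters as the first positional argument, which is how `torch.optim` constructors are called. To stay self-contained it uses a hypothetical `SGDStandIn` class instead of PyTorch.

```python
from collections.abc import Iterable
from functools import partial


class SGDStandIn:
    """Hypothetical stand-in mimicking a torch.optim optimizer constructor."""

    def __init__(self, params: Iterable[float], lr: float, momentum: float = 0.0) -> None:
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum


# User side: fix only the hyperparameters, as with partial(torch.optim.SGD, ...).
optimizer_factory = partial(SGDStandIn, lr=0.01, momentum=0.9)

# Trainer side (assumed): bind the model parameters once the model is known.
model_parameters = [0.5, -1.2, 3.0]
optimizer = optimizer_factory(model_parameters)
print(optimizer.lr, optimizer.params)  # 0.01 [0.5, -1.2, 3.0]
```

Passing a factory instead of an instantiated optimizer presumably lets the caller choose the optimizer and its hyperparameters before the model's parameters are available to the trainer.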
Methods:
Name | Description |
---|---|
build_loss | Build a loss object, inferring the required size from the model and fitted_schema. |
train | Train the model using the trainer configuration and the given datasets. |
Attributes:
Name | Type | Description |
---|---|---|
optimizer | | |
max_epochs | | |
start_epoch | | |
callbacks | | |
metrics | | |
Source code in src/xpdeep/trainer/trainer.py
optimizer = optimizer
max_epochs = max_epochs
start_epoch = start_epoch
callbacks = callbacks
metrics = metrics
build_loss(loss: _Loss, fitted_schema: FittedSchema, model: XpdeepModel) -> MainLossRequestBody
Build a loss object, inferring the required size from the model and fitted_schema.
train(model: XpdeepModel, train_set: FittedParquetDataset, *, validation_set: FittedParquetDataset | None = None, with_cache: bool = False, batch_size: int | None = None, seed: int | None = None) -> TrainedModelArtifact
Train the model using the trainer configuration and the given datasets.
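The reference does not describe the training loop itself. As a hedged, framework-agnostic sketch of how the keyword arguments of `train` and the `max_epochs`, `start_epoch` and `callbacks` constructor arguments commonly interact, the code below runs epochs from `start_epoch` up to `max_epochs`, shuffles with a seeded RNG, splits the data into batches of `batch_size`, evaluates on the optional validation set, and calls every callback at the end of each epoch. All names (`train_sketch`, `train_step`, `eval_step`, the `(epoch, logs)` callback signature) are hypothetical; xpdeep's `Trainer.train` may behave differently, including in whether `max_epochs` is counted from `start_epoch`.

```python
from __future__ import annotations

import random
from collections.abc import Callable, Sequence

# Hypothetical callback signature: receives the epoch index and that epoch's logs.
Callback = Callable[[int, dict[str, float]], None]


def train_sketch(
    train_set: Sequence[float],
    train_step: Callable[[list[float]], float],
    eval_step: Callable[[list[float]], float],
    *,
    max_epochs: int,
    start_epoch: int | None = None,
    validation_set: Sequence[float] | None = None,
    batch_size: int | None = None,
    seed: int | None = None,
    callbacks: list[Callback] | None = None,
) -> list[dict[str, float]]:
    """Illustrative epoch loop; not xpdeep's implementation."""
    if not train_set:
        raise ValueError("train_set must not be empty")
    if batch_size is not None and batch_size <= 0:
        raise ValueError("batch_size must be positive")
    rng = random.Random(seed)  # same seed -> same shuffling order
    # Assumption: batch_size=None means one full batch per epoch.
    size = batch_size if batch_size is not None else len(train_set)
    history: list[dict[str, float]] = []
    # Assumption: max_epochs is an absolute bound, so resuming at start_epoch
    # runs only the remaining epochs.
    for epoch in range(start_epoch or 0, max_epochs):
        indices = list(range(len(train_set)))
        rng.shuffle(indices)
        losses = [
            train_step([train_set[i] for i in indices[begin : begin + size]])
            for begin in range(0, len(indices), size)
        ]
        logs = {"epoch": float(epoch), "train_loss": sum(losses) / len(losses)}
        if validation_set is not None:
            logs["val_loss"] = eval_step(list(validation_set))
        for callback in callbacks or []:
            callback(epoch, logs)
        history.append(logs)
    return history
```

Seeding a dedicated `random.Random` instance rather than the global generator keeps the shuffle order reproducible without affecting other code that uses `random`.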