Skip to content

trainer

Implement the trainer interface.

Modules:

Name Description
jobs

Jobs utilities.

Classes:

Name Description
CancelMessage

Cancel message for the training generator.

Trainer

Interface to train a model.

Attributes:

Name Type Description
T

T = TypeVar('T') #

CancelMessage #

Cancel message for the training generator.

Trainer(loss: _Loss, optimizer: partial[Optimizer], *, max_epochs: int, start_epoch: int | None = None, callbacks: list[Callback] | None = None, metrics: DictMetrics | None = None) #

Interface to train a model.

Initialize the trainer.

Parameters:

Name Type Description Default

loss #

_Loss

The loss to use during the training.

required

optimizer #

partial[Optimizer]

A partial torch optimizer.

required

max_epochs #

int

Max epochs.

required

start_epoch #

int | None

The start epoch when the training is resumed.

None

callbacks #

list[Callback] | None

Optional list of callbacks to use during training.

None

metrics #

DictMetrics | None

Optional dictionary of metrics to use during training.

None

Methods:

Name Description
build_loss

Build a loss request body, inferring the required sizes from the model and the fitted schema.

train

Train the model following the config and the dataset.

Attributes:

Name Type Description
optimizer
max_epochs
start_epoch
callbacks
metrics
Source code in src/xpdeep/trainer/trainer.py
def __init__(  # noqa:PLR0913
    self,
    loss: _Loss,
    optimizer: partial[Optimizer],
    *,
    max_epochs: int,
    start_epoch: int | None = None,
    callbacks: list[Callback] | None = None,
    metrics: DictMetrics | None = None,
) -> None:
    """Store the training configuration on the trainer.

    Parameters
    ----------
    loss : _Loss
        Loss function used during training.
    optimizer : partial[Optimizer]
        Partially-applied torch optimizer, stored exactly as given.
    max_epochs : int
        Maximum number of training epochs.
    start_epoch : int | None, default None
        Epoch to start from when a previous training is resumed.
    callbacks : list[Callback] | None, default None
        Optional callbacks used during training.
    metrics : DictMetrics | None, default None
        Optional dictionary of metrics used during training.
    """
    # The loss is kept on a private attribute; everything else is public.
    self._loss = loss
    self.optimizer = optimizer
    # Epoch bounds (start_epoch stays None for a fresh training).
    self.max_epochs, self.start_epoch = max_epochs, start_epoch
    # Optional extras are stored as-is, without copying; None when omitted.
    self.callbacks, self.metrics = callbacks, metrics

optimizer = optimizer #

max_epochs = max_epochs #

start_epoch = start_epoch #

callbacks = callbacks #

metrics = metrics #

build_loss(loss: _Loss, fitted_schema: FittedSchema, model: XpdeepModel) -> MainLossRequestBody #

Build a loss request body, inferring the required sizes from the model and the fitted schema.

Source code in src/xpdeep/trainer/trainer.py
@staticmethod
def build_loss(loss: _Loss, fitted_schema: FittedSchema, model: XpdeepModel) -> MainLossRequestBody:
    """Build the main loss request body for ``model`` and ``fitted_schema``.

    The target size is ``fitted_schema.target_size`` without its first entry.
    The output size is the value reported by ``model.get_output_size``.
    Both are passed, with ``loss``, to ``TorchModel.from_torch_loss``.
    """
    # Skip the leading entry of the target size (presumably the batch dimension — TODO confirm).
    expected_target_size = fitted_schema.target_size[1:]
    serialized_loss = TorchModel.from_torch_loss(
        loss,
        target_size=expected_target_size,
        output_size=model.get_output_size(fitted_schema),
    )
    return MainLossRequestBody(torch_loss=serialized_loss)

train(model: XpdeepModel, train_set: FittedParquetDataset, *, validation_set: FittedParquetDataset | None = None, with_cache: bool = False, batch_size: int | None = None, seed: int | None = None) -> TrainedModelArtifact #

Train the model following the config and the dataset.

Source code in src/xpdeep/trainer/trainer.py
@initialized_client_verification
@initialized_project_verification
def train(  # noqa: PLR0913
    self,
    model: XpdeepModel,
    train_set: FittedParquetDataset,
    *,
    validation_set: FittedParquetDataset | None = None,
    with_cache: bool = False,
    batch_size: int | None = None,
    seed: int | None = None,
) -> TrainedModelArtifact:
    """Train the model following the config and the dataset.

    Submit a trained-model creation job for the current project and report
    its progress. Then re-fetch the job and load the trained model it
    references.

    Parameters
    ----------
    model : XpdeepModel
        The model to train.
    train_set : FittedParquetDataset
        Training dataset. Its fitted schema is also used to build the trainer
        request body.
    validation_set : FittedParquetDataset | None, default None
        Optional validation dataset.
    with_cache : bool, default False
        Forwarded as ``with_cache`` in the training request.
    batch_size : int | None, default None
        Batch size forwarded to the batch configuration.
    seed : int | None, default None
        Seed forwarded to the batch configuration.

    Returns
    -------
    TrainedModelArtifact
        The trained model whose id is found in the job results.

    Raises
    ------
    ApiError
        If the job status is ``JobStatus.ERROR``, or if the job has no results.
    """
    warnings.warn(
        "Interrupting training process will cancel the job. Currently the job will still run as a "
        "background process even after interrupting the 'train' method.",
        category=FutureWarning,
        stacklevel=1,
    )

    error_msg = "Training failed"

    body = TrainedModelCreateRequestBody(
        trainer=self._as_request_body(train_set.fitted_schema, model),
        model=model._as_request_body,
        train_set=train_set._as_request_body,
        batch_config=BuildBatchConfigRequestBody(batch_size=batch_size, seed=seed),
        read_dataset_config=BuildReadDatasetConfigRequestBody(),
        validation_set=validation_set._as_request_body if validation_set is not None else None,
        with_cache=with_cache,
    )

    # Resolve the project id once; a fresh client is still built per request, as before.
    project_id = Project.CURRENT.get().model.id

    training_model_job = cast(
        JobModel,
        create_trained_model.sync(project_id, body=body, client=ClientFactory.CURRENT.get()()),
    )

    # NOTE(review): presumably blocks until the job reaches a final state — confirm in ``jobs``.
    jobs.print_job_progress(training_model_job.id)

    # Re-fetch the job to read its final status and results.
    training_model_job = cast(
        JobModel,
        get_one_job.sync(project_id, training_model_job.id, client=ClientFactory.CURRENT.get()()),
    )
    if training_model_job.status == JobStatus.ERROR:
        # A failed job may carry no results or no "error_detail" entry. Fall back to
        # the generic message instead of masking the failure with a TypeError/KeyError.
        try:
            error_detail = cast(JobModelResultsType0, training_model_job.results)["error_detail"]
        except (KeyError, TypeError):
            error_detail = error_msg
        raise ApiError(error_detail)

    if training_model_job.results is None:
        raise ApiError(error_msg)

    return TrainedModelArtifact.from_dict(
        cast(
            TrainedModelModel,
            get_one_trained_model.sync(
                project_id,
                training_model_job.results["id"],
                client=ClientFactory.CURRENT.get()(),
            ),
        ).to_dict()
    )