
How to Train an Xpdeep Model#

You can train your Xpdeep model (an XpdeepModel object) using the Trainer interface.

The Trainer is fully customizable and encapsulates your hyperparameters, callbacks, and other training configuration.

Trainer-related methods are asynchronous, meaning you don't have to wait for their completion before issuing a new API call.

Future Release

The training process will be interruptible with Ctrl+C.


1. Build the Trainer#

Let's dive into the Trainer configuration.

loss: a loss function. The expected output of the loss function is a 1D tensor of shape (batch_size,). For ease of use, if the loss function output has more than one dimension, it is automatically averaged into a 1D tensor. The loss takes the predictions (the task learner model output) as its first argument and the preprocessed target as its second argument.

Tip

Most of PyTorch's built-in loss functions respect this convention when instantiated with reduction="none" (e.g. torch.nn.MSELoss(reduction="none")).
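
To make the expected shapes concrete, here is a minimal sketch of the convention (tensor shapes and values are hypothetical):

import torch

# Hypothetical shapes: a batch of 32 predictions with 3 output dimensions.
predictions = torch.randn(32, 3)  # task learner model output
targets = torch.randn(32, 3)      # preprocessed targets

# With reduction="none", the loss keeps the per-sample dimension.
loss_fn = torch.nn.MSELoss(reduction="none")
per_sample_loss = loss_fn(predictions, targets)  # shape (32, 3)

# Extra dimensions are averaged automatically, roughly equivalent to:
per_sample_loss = per_sample_loss.mean(dim=-1)   # shape (32,)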

Future Release

It will be possible to specify custom "loss outputs" when designing models, to leverage loss functions that require inputs other than the predictions and targets.

optimizer: any PyTorch optimizer. It must be provided as a partial optimizer that does not specify the "params" argument, as the association with the model parameters is done server-side.
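
For instance, the AdamW optimizer can be passed as a partial (the learning rate here is arbitrary):

from functools import partial
from torch.optim import AdamW

# "params" is intentionally left out: the model parameters are bound server-side.
optimizer = partial(AdamW, lr=0.01)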

metrics: you can set your own metrics in a DictMetrics structure. Similarly to the loss function, metrics are computed between the model predictions and the targets. Any metric from torchmetrics that respects this convention is supported (as a partial) and will be serialized by its name, again for security reasons. If you add a metric from torchmetrics, two types of metrics will be computed automatically:

  1. Global model metrics (TorchGlobalMetric), which relate to the model's overall performance.
  2. Leaf metrics (TorchLeafMetric), which provide detailed metrics for each leaf (or predictive region) of the model, offering more granular insights.

Please note that the dictionary keys will be used as metric names in XpViz.

callbacks: here you can define specific callbacks.

Finally, you can provide a set of self-explanatory parameters, such as max_epochs.

Tip

You can directly specify a TorchGlobalMetric or a TorchLeafMetric to compute global or per-leaf metrics only. In addition, both provide additional parameters, such as on_raw_data, that let you compute the metric on raw or preprocessed data.

For instance, if you scaled your data at the preprocessing stage and use an MSE metric wrapped in a TorchGlobalMetric with on_raw_data=True, the MSE will be computed on the raw data and not the scaled data.
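
As an illustrative sketch only (the exact TorchGlobalMetric arguments and import path are assumptions here; on_raw_data follows the tip above), this could look like:

from functools import partial
from torchmetrics import MeanSquaredError
from xpdeep.metrics.metric import DictMetrics, TorchGlobalMetric  # assumed import path

# Assumed constructor arguments: a global MSE computed on raw (unscaled) data.
metrics = DictMetrics(
    raw_mse=TorchGlobalMetric(metric=partial(MeanSquaredError), on_raw_data=True),
)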

Warning

Currently, torch metrics that do not take predictions and targets in the same representation space are not supported (e.g. one-hot predictions and label targets in classification tasks). Check xpdeep.metrics.zoo.multiclass_metrics for corrected implementations of some common metrics.

Future Release

For more flexibility, it will be possible to specify transform functions that are applied to the model predictions and to the targets prior to metric computation.

import torch.nn as nn
from torch.optim import AdamW
from xpdeep.metrics.zoo.multiclass_metrics import MulticlassConfusionMatrix, MulticlassF1Score
from functools import partial
from xpdeep.metrics.metric import DictMetrics
from xpdeep.trainer.callbacks import EarlyStopping
from xpdeep.trainer.trainer import Trainer

metrics = DictMetrics(
    confusion_matrix=partial(MulticlassConfusionMatrix, num_classes=3),
    f1_score=partial(MulticlassF1Score, num_classes=3),
)

trainer = Trainer(
    loss=nn.MSELoss(reduction="none"),
    optimizer=partial(AdamW, lr=0.01),
    callbacks=[EarlyStopping(monitoring_metric="Total loss", mode="minimize")],
    metrics=metrics,
    max_epochs=5
)
👀 Full file preview
from functools import partial

from build_model import xpdeep_model
from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torch import nn
from torch.optim import AdamW
from xpdeep.metrics.metric import DictMetrics
from xpdeep.metrics.zoo.multiclass_metrics import MulticlassConfusionMatrix, MulticlassF1Score
from xpdeep.trainer.callbacks import EarlyStopping
from xpdeep.trainer.trainer import Trainer

metrics = DictMetrics(
    confusion_matrix=partial(MulticlassConfusionMatrix, num_classes=3),
    f1_score=partial(MulticlassF1Score, num_classes=3),
)

trainer = Trainer(
    loss=nn.MSELoss(reduction="none"),
    optimizer=partial(AdamW, lr=0.01),
    callbacks=[EarlyStopping(monitoring_metric="Total loss", mode="minimize")],
    metrics=metrics,
    max_epochs=5,
)

trained_model = trainer.train(
    xpdeep_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset, batch_size=32
)

2. Train the Model#

With your Trainer, you can finally train the explainable model and get the trained model back as a TrainedModelArtifact.

trained_model = trainer.train(xpdeep_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset, batch_size=32)
👀 Full file preview
from functools import partial

from build_model import xpdeep_model
from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torch import nn
from torch.optim import AdamW
from xpdeep.metrics.metric import DictMetrics
from xpdeep.metrics.zoo.multiclass_metrics import MulticlassConfusionMatrix, MulticlassF1Score
from xpdeep.trainer.callbacks import EarlyStopping
from xpdeep.trainer.trainer import Trainer

metrics = DictMetrics(
    confusion_matrix=partial(MulticlassConfusionMatrix, num_classes=3),
    f1_score=partial(MulticlassF1Score, num_classes=3),
)

trainer = Trainer(
    loss=nn.MSELoss(reduction="none"),
    optimizer=partial(AdamW, lr=0.01),
    callbacks=[EarlyStopping(monitoring_metric="Total loss", mode="minimize")],
    metrics=metrics,
    max_epochs=5,
)

trained_model = trainer.train(
    xpdeep_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset, batch_size=32
)

Under the hood, the trained model will be saved as a trained model artifact within your Project.

You can follow the training logs in your terminal.

Future Release

Insights and logs will be provided as artifacts.

3. Evaluate the Model#

Model evaluation to assess performance is carried out through the Explainer interface, using the inference method. A wide range of metrics is supported; see Explain.