How to Train an Xpdeep Model#
You can train your Xpdeep model, an XpdeepModel object, using the Trainer interface.
The Trainer is fully customizable and encapsulates your hyperparameters, callbacks, and other training configuration.
Trainer-related methods are asynchronous, meaning you don't have to wait for one call to complete before starting a new API call.
Future Release
The training process will be interruptible with Ctrl+C.
Requirements:
- an XpdeepModel, see build model.
- train and validation sets as FittedParquetDataset, see create dataset.
1. Build the Trainer#
Let's dive into the Trainer configuration.
loss: a loss function. The expected output of the loss function is a 1-D tensor of shape (batch_size,).
For ease of use, if the loss output has more than one dimension, it is automatically averaged into a 1-D tensor.
The loss takes the predictions (the task learner model output) as its first argument and the preprocessed target as its second argument.
Tip
Most of PyTorch's built-in loss functions respect this convention when instantiated with reduction="none" (e.g. torch.nn.MSELoss(reduction="none")).
Future Release
It will be possible to specify custom "loss outputs" when designing models, to leverage loss functions that require inputs other than the predictions and targets.
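As a quick sanity check of the current convention, the following standalone snippet (plain PyTorch; the shapes are purely illustrative) shows a loss that returns one value per sample:
import torch
import torch.nn as nn
# A conforming loss returns one value per sample, i.e. a 1-D tensor of shape (batch_size,).
loss_fn = nn.CrossEntropyLoss(reduction="none")
predictions = torch.randn(8, 3)  # task learner output (batch_size, num_classes)
targets = torch.randint(0, 3, (8,))  # preprocessed targets
print(loss_fn(predictions, targets).shape)  # torch.Size([8])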
optimizer: any PyTorch optimizer. The optimizer should be provided as a partial that does not specify the "params" argument, as the association with the model parameters is done server-side.
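For instance, a valid partial optimizer looks like this (the hyperparameter values are illustrative):
from functools import partial
from torch.optim import AdamW
# "params" is intentionally omitted: the model parameters are bound server-side.
optimizer = partial(AdamW, lr=1e-3, weight_decay=1e-2)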
metrics: you can set your own metrics in a DictMetrics structure. Similarly to the loss function, metrics are computed between the model predictions and the targets. Any metric from torchmetrics that respects this convention is supported (as a partial) and will be serialized by name, again for security reasons.
If you add a metric from torchmetrics, two types of metrics will be automatically computed:
- Global model metrics (TorchGlobalMetric), which relate to the model's overall performance.
- Leaf metrics (TorchLeafMetric), which provide detailed metrics specific to each leaf (or predictive region) of the model, offering more granular insights.
Please note that the dictionary keys will be used as metric names in XpViz.
callbacks: here you can define specific callbacks.
Finally, you can provide a set of self-explanatory parameters, such as max_epochs.
Tip
You can directly specify a TorchGlobalMetric or a TorchLeafMetric to compute global or per-leaf metrics only.
In addition, both provide extra parameters such as on_raw_data that let you compute the metric on raw or preprocessed data.
For instance, if you scaled your data in the preprocessing stage and use an MSE metric within a TorchGlobalMetric with on_raw_data=True, the MSE will be computed on the raw data rather than the scaled data.
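As a sketch only (the TorchGlobalMetric import path and constructor used here are assumptions and may differ in your xpdeep version), a global MSE computed on raw data could be declared like this:
from functools import partial
from torchmetrics import MeanSquaredError
from xpdeep.metrics.metric import DictMetrics, TorchGlobalMetric  # import path assumed
# Assumed constructor: TorchGlobalMetric wraps a partial torchmetrics metric and exposes on_raw_data.
metrics = DictMetrics(
    mse_on_raw=TorchGlobalMetric(partial(MeanSquaredError), on_raw_data=True),
)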
Warning
Currently, metrics from torchmetrics that do not take predictions and targets in the same representation space are not supported (e.g. one-hot predictions with label targets in classification tasks). Check xpdeep.metrics.zoo.multiclass_metrics for corrected implementations of some common metrics.
Future Release
For more flexibility, it will be possible to specify transform functions that will be applied to the model predictions and to the targets prior to metric computation.
import torch.nn as nn
from torch.optim import AdamW
from xpdeep.metrics.zoo.multiclass_metrics import MulticlassConfusionMatrix, MulticlassF1Score
from functools import partial
from xpdeep.metrics.metric import DictMetrics
from xpdeep.trainer.callbacks import EarlyStopping
from xpdeep.trainer.trainer import Trainer
metrics = DictMetrics(
    confusion_matrix=partial(MulticlassConfusionMatrix, num_classes=3),
    f1_score=partial(MulticlassF1Score, num_classes=3),
)
trainer = Trainer(
loss=nn.MSELoss(reduction="none"),
optimizer=partial(AdamW, lr=0.01),
callbacks=[EarlyStopping(monitoring_metric="Total loss", mode="minimize")],
metrics=metrics,
max_epochs=5
)
👀 Full file preview
from functools import partial
from build_model import xpdeep_model
from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torch import nn
from torch.optim import AdamW
from torchmetrics import Accuracy, F1Score
from xpdeep.metrics.metric import DictMetrics
from xpdeep.trainer.callbacks import EarlyStopping
from xpdeep.trainer.trainer import Trainer
metrics = DictMetrics(
f1_score=partial(F1Score, task="multiclass", num_classes=3),
accuracy=partial(Accuracy, task="multiclass", num_classes=3),
)
trainer = Trainer(
    loss=nn.MSELoss(reduction="none"),
optimizer=partial(AdamW, lr=0.01),
callbacks=[EarlyStopping(monitoring_metric="Total loss", mode="minimize")],
metrics=metrics,
max_epochs=5,
)
trained_model = trainer.train(
xpdeep_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset, batch_size=32
)
2. Train the Model#
With your trainer, you can finally train the explainable model and get a trained model as a TrainedModelArtifact.
trained_model = trainer.train(xpdeep_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset, batch_size=32)
👀 Full file preview
from functools import partial
from build_model import xpdeep_model
from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torch import nn
from torch.optim import AdamW
from torchmetrics import Accuracy, F1Score
from xpdeep.metrics.metric import DictMetrics
from xpdeep.trainer.callbacks import EarlyStopping
from xpdeep.trainer.trainer import Trainer
metrics = DictMetrics(
f1_score=partial(F1Score, task="multiclass", num_classes=3),
accuracy=partial(Accuracy, task="multiclass", num_classes=3),
)
trainer = Trainer(
    loss=nn.MSELoss(reduction="none"),
optimizer=partial(AdamW, lr=0.01),
callbacks=[EarlyStopping(monitoring_metric="Total loss", mode="minimize")],
metrics=metrics,
max_epochs=5,
)
trained_model = trainer.train(
xpdeep_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset, batch_size=32
)
Under the hood, the trained model will be saved as a trained model artifact within your Project
.
You can get training logs in your terminal.
Future Release
Insights and logs will be provided as artifacts.
3. Evaluate the Model#
Model evaluation to assess performance is carried out through the Explainer interface, using the inference method.
A wide range of metrics is supported; see Explain.