From a pytorch model to a deep explainable model#
For a quick introduction to the Xpdeep APIs, this section demonstrates, on the Bike dataset, how to adapt a standard deep model's PyTorch code to transition to designing an explainable deep model.
We will review the key steps involved in designing a deep model with Xpdeep, from architecture specification and training to generating explanations.
For each step in building a deep model, we provide:
- Tabs labeled "SOTA and Xpdeep" for code that is identical for both the SOTA deep model and the Xpdeep explainable model.
- Tabs labeled "Xpdeep" for code specific to the Xpdeep explainable model.
1. Project Setup#
Set up API Key and URL#
Create a Project#
2. Data preparation#
Read Raw Data#
import pandas as pd

# Load the raw Bike dataset splits from disk.
test_data = pd.read_csv("test.csv")
train_val_data = pd.read_csv("train.csv")

# Drop columns not used as inputs.
# NOTE(review): "casual"/"registered" are presumably dropped because they
# decompose the target "count" (target leakage) — confirm against the dataset
# description. "atemp" ("feels-like" temperature) is redundant with "temp".
test_data = test_data.drop(columns=["atemp"])
train_val_data = train_val_data.drop(columns=["casual", "atemp", "registered"])

# Expand the raw timestamp into calendar features, then drop the original
# "datetime" column (the model consumes only the derived features).
for dataset in [test_data, train_val_data]:
    dataset["datetime"] = pd.to_datetime(dataset["datetime"])
    dataset["year"] = dataset["datetime"].dt.year
    dataset["month"] = dataset["datetime"].dt.month
    dataset["hour"] = dataset["datetime"].dt.hour
    dataset["weekday"] = dataset["datetime"].dt.weekday
    dataset.drop(columns=["datetime"], inplace=True)
Split Data#
Conversion to Parquet Format#
import pyarrow as pa
import pyarrow.parquet as pq

# Convert each split to a pyarrow Table (dropping the pandas index) and
# persist it as a ".parquet" file.
train_table = pa.Table.from_pandas(train_data, preserve_index=False)
pq.write_table(train_table, "train.parquet")

val_table = pa.Table.from_pandas(val_data, preserve_index=False)
pq.write_table(val_table, "val.parquet")

test_table = pa.Table.from_pandas(test_data, preserve_index=False)
pq.write_table(test_table, "test.parquet")
Upload#
Preprocess Data#
from sklearn.preprocessing import OneHotEncoder, StandardScaler
import numpy as np

# Feature groups driving the preprocessing below.
numerical_features = ["temp", "humidity", "windspeed"]
categorical_features = ["season", "holiday", "workingday", "weather", "year", "month", "hour", "weekday"]
target_feature = "count"

# Fit every preprocessor on the TRAIN split only, so no statistics leak
# from the validation/test splits.
numerical_features_standard_scaler = StandardScaler().fit(train_data[numerical_features])
categorical_features_encoders = {
    category: OneHotEncoder(sparse_output=False).fit(train_data[[category]])
    for category in categorical_features
}
target_feature_encoder = StandardScaler().fit(train_data[[target_feature]])


def _transform_split(data):
    """Return ``(x, y)`` for one split.

    ``x`` concatenates the standardized numerical features with the one-hot
    encoded categorical features (in that order); ``y`` is the standardized
    target column.
    """
    x = np.concatenate(
        [numerical_features_standard_scaler.transform(data[numerical_features])]
        + [categorical_features_encoders[feature].transform(data[[feature]]) for feature in categorical_features],
        axis=1,
    )
    y = target_feature_encoder.transform(data[[target_feature]])
    return x, y


# Transform data (same pipeline for every split).
x_train, y_train = _transform_split(train_data)
x_test, y_test = _transform_split(test_data)
x_val, y_val = _transform_split(val_data)

# input and output sizes
input_size = x_train.shape[1]
target_size = y_train.shape[1]
from xpdeep.dataset.parquet_dataset import FittedParquetDataset, ParquetDataset

# 1/ Create Analyzed Parquet on Train Dataset.
# NOTE(review): `directory` is created elsewhere (upload step) and appears to
# map split names to uploaded parquet paths — confirm against the Upload section.
train_dataset = ParquetDataset(
split_name="train",
identifier_name="my_local_dataset",
path=directory["train_set_path"],
)
# Analyze the train split, declaring "count" as the prediction target,
# then inspect the inferred schema.
analyzed_train_dataset = train_dataset.analyze(target_names=["count"])
print(analyzed_train_dataset.analyzed_schema)
# 2/ Create Fitted Parquet Datasets.
# The schema is fitted on the train split only and reused for the test and
# validation splits, so all splits share the same preprocessing.
fit_train_dataset = analyzed_train_dataset.fit()
fit_test_dataset = FittedParquetDataset(
split_name="test",
identifier_name="my_local_dataset",
path=directory["test_set_path"],
fitted_schema=fit_train_dataset.fitted_schema,
)
fit_val_dataset = FittedParquetDataset(
split_name="validation",
identifier_name="my_local_dataset",
path=directory["val_set_path"],
fitted_schema=fit_train_dataset.fitted_schema,
)
# Input and output sizes, read back from the fitted schema (index 0 is
# presumably the batch dimension — verify against the xpdeep docs).
input_size = fit_train_dataset.fitted_schema.input_size[1]
target_size = fit_train_dataset.fitted_schema.target_size[1]
3. Model Construction#
Architecture Specification#
Model Instantiation#
from xpdeep.model.model_builder import ModelDecisionGraphParameters
from xpdeep.model.xpdeep_model import XpdeepModel

# Explanation Architecture: hyperparameters of the model's decision graph
# (depth, pruning thresholds, loss weights). See the xpdeep documentation
# for each parameter's exact semantics.
explanation_architecture = ModelDecisionGraphParameters(
graph_depth=3,
discrimination_weight=0.1,
target_homogeneity_weight=0.1,
target_homogeneity_pruning_threshold=0.2,
population_pruning_threshold=0.2,
balancing_weight=0.1,
prune_step=30,
)
# XPDEEP Model Architecture: wrap the torch modules (`feature_extractor` and
# `task_learner`, defined in the Architecture Specification step) together
# with the fitted schema and the explanation parameters above.
xpdeep_model = XpdeepModel.from_torch(
fitted_schema=fit_train_dataset.fitted_schema,
feature_extraction=feature_extractor,
task_learner=task_learner,
decision_graph_parameters=explanation_architecture,
)
4. Training#
Training Specification#
from xpdeep.trainer.callbacks import EarlyStopping, Scheduler
from functools import partial
from xpdeep.metric import DictMetrics, TorchGlobalMetric, TorchLeafMetric
from torch.optim.lr_scheduler import ReduceLROnPlateau
from xpdeep.trainer.trainer import Trainer
from torchmetrics import MeanSquaredError
from torch import nn
import torch  # needed below for torch.optim.AdamW; was not imported in this snippet

# Metrics: one global MSE plus a per-leaf MSE. `on_raw_data=True` presumably
# evaluates on unscaled targets — confirm against the xpdeep metric docs.
metrics = DictMetrics(
    mse=TorchGlobalMetric(metric=partial(MeanSquaredError), on_raw_data=True),
    leaf_metric_mse=TorchLeafMetric(metric=partial(MeanSquaredError), on_raw_data=True),
)

# Stop after 5 epochs without improvement; reduce the LR on plateaus of the
# same monitored "Total loss".
callbacks = [
    EarlyStopping(monitoring_metric="Total loss", mode="minimize", patience=5),
    Scheduler(pre_scheduler=partial(ReduceLROnPlateau), step_method="epoch", monitoring_metric="Total loss"),
]

# XPDEEP Training Specifications.
# The loss uses reduction="none" (per-sample losses); the optimizer is passed
# as a factory via functools.partial.
trainer = Trainer(
    loss=nn.MSELoss(reduction="none"),
    optimizer=partial(torch.optim.AdamW, lr=0.001, foreach=False, fused=False),
    start_epoch=0,
    max_epochs=60,
    metrics=metrics,
    callbacks=callbacks,
)
Model Training#
from sklearn.metrics import mean_squared_error, root_mean_squared_error
import torch
device = "cpu"


def train(X_train, y_train, model, loss_fn, optimizer):
    """Run one training epoch over ``X_train``/``y_train`` in mini-batches.

    Relies on the module-level ``batch_size`` and ``device``. The trailing
    partial batch (``len(X_train) % batch_size`` samples) is dropped.
    Returns the average per-batch loss for the epoch.
    """
    size = len(X_train)
    model.train()
    total_loss = 0
    num_batches = size // batch_size
    for batch in range(num_batches):
        # Slice the current mini-batch and move it to the target device.
        start, stop = batch * batch_size, (batch + 1) * batch_size
        X_batch = torch.tensor(X_train[start:stop, :], dtype=torch.float32).to(device)
        y_batch = torch.tensor(y_train[start:stop, :], dtype=torch.float32).to(device)

        # Compute prediction error.
        pred = model(X_batch)
        loss = loss_fn(pred, y_batch)

        # Backpropagation.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    average_loss = total_loss / num_batches
    return average_loss
def eval_(X_test, y_test, model, loss_fn):
    """Evaluate ``model`` on a full split in a single forward pass.

    Returns ``(predictions, test_loss, mse, rmse)`` where ``test_loss`` is
    computed on the standardized target while ``predictions``/``mse``/``rmse``
    are reported on the original "count" scale via the module-level
    ``target_feature_encoder``. Uses the module-level ``device``.
    """
    model.eval()
    with torch.no_grad():
        X_test = torch.tensor(X_test, dtype=torch.float32).to(device)
        y_test = torch.tensor(y_test, dtype=torch.float32).to(device)
        pred = model(X_test)
        test_loss = loss_fn(pred, y_test).item()
        # Invert the target scaling once and reuse it (the original
        # recomputed inverse_transform(pred) for the return value).
        y_raw = target_feature_encoder.inverse_transform(y_test)
        pred_raw = target_feature_encoder.inverse_transform(pred)
        mse = mean_squared_error(y_raw, pred_raw)
        rmse = root_mean_squared_error(y_raw, pred_raw)
    return pred_raw, test_loss, mse, rmse
# SOTA training loop: `epochs`, `sota_model`, `loss_fn` and `optimizer` are
# defined in the (SOTA-side) model construction step.
for t in range(epochs):
    print(f"\nEpoch {t+1}\n-------------------------------")
    training_loss = train(x_train, y_train, sota_model, loss_fn, optimizer)
    _, val_loss, _, _ = eval_(x_val, y_val, sota_model, loss_fn)
    print(f"Training Loss: {training_loss}\nValidation Loss: {val_loss}")

# Final evaluation on all three splits (MSE reported on the raw scale).
_, _, mse_on_train, rmse_on_train = eval_(x_train, y_train, sota_model, loss_fn)
_, _, mse_on_validation, rmse_on_validation = eval_(x_val, y_val, sota_model, loss_fn)
_, _, mse_on_test, rmse_on_test = eval_(x_test, y_test, sota_model, loss_fn)
print(f"\nMSEs: "
      f"\nMSE on train set : {mse_on_train}"
      f"\nMSE on validation set : {mse_on_validation}"
      f"\nMSE on test set : {mse_on_test}"
      )
5. Explanation Generation#
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, HistogramStat, VarianceStat

# Per-node statistics to compute for the explanations: histograms and
# variances of targets, predictions and prediction errors, on raw
# (unscaled) data per `on_raw_data=True`.
statistics = DictStats(
histogram_target=HistogramStat(on="target", num_bins=20, num_items=1000, on_raw_data=True),
histogram_prediction=HistogramStat(on="prediction", num_bins=20, num_items=1000, on_raw_data=True),
histogram_error=HistogramStat(on="prediction_error", num_bins=20, num_items=1000, on_raw_data=True),
variance_target=VarianceStat(on="target", on_raw_data=True),
variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)
# Explanation quality metrics (see the xpdeep docs for their definitions).
quality_metrics = [Sensitivity(), Infidelity()]
# NOTE(review): `metrics` is the DictMetrics object from the training step.
explainer = Explainer(
description_representativeness=1000, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
)
# Build the global (model-level) explanations over all three splits, then
# print the link to the interactive visualisation.
model_explanations = explainer.global_explain(
trained_model,
train_set=fit_train_dataset,
test_set=fit_test_dataset,
validation_set=fit_val_dataset,
)
print(model_explanations.visualisation_link)