How to Compute Explanations#
With an explainable dataset FittedParquetDataset and a trained model artifact TrainedModelArtifact, Xpdeep generates two types of explanations through the interactive XpViz interface:
- Model functioning explanations, represented by a graph titled 'Model Decision Graph', which illustrates how the learned model works: the set of decisions made to accomplish the task, the key features involved and their importance, the reliability and robustness of the explanations, model performance, and various metrics and statistics related to explanation quality.
- Inference causal explanations, represented by a graph titled 'Inference Graph', which details the entire set of decisions involved in generating a prediction for an individual or group of individuals. They highlight the key causal features and their importance, the explanations' reliability and robustness, and various metrics and statistics related to the resulting prediction.
For more information on Xpdeep explanations, refer to the explanations section.
Tip
A video tutorial on XpViz is available to guide you through the XpViz interface and introduce all the explanations.
1. Build the Explainer#
Given a FittedParquetDataset and a trained XpdeepModel, the Explainer interface allows you to:
- Generate the model explanations into the Model Decision Graph.
- Generate predictions with their explanations into the Inference Graph.
Future Release
Generate predictions without any explanations, as in a standard deep model.
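In code, these two capabilities map to the explainer's global_explain and local_explain methods, previewed here (this sketch assumes an explainer, a trained_model, fitted datasets, and a dataset_filter built as in the sections below):
# Model functioning explanations -> Model Decision Graph
model_decision = explainer.global_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset
)
# Inference causal explanations -> Inference Graph
causal_explanations = explainer.local_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, dataset_filter=dataset_filter
)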
2. Get Explanations with the Explainer#
Given the TrainedModelArtifact obtained once the model is trained, let's detail the Explainer characteristics and required input parameters:
- description_representativeness: each node/leaf of the graph corresponds to a distribution of individuals. To provide explanations, metrics, or statistics for each node/leaf, a representative subset is selected. This parameter determines the size of that representative subset. The larger the subset, the more precise the explanations, statistics, and metrics, but the higher the computational cost.
- windows_size: an optional parameter specific to time series data. It lets you specify the window size used to warp and synchronize temporal data, for example when using the Dynamic Time Warping (DTW) metric (see the sketch after the tip below).
- quality_metrics: specifies the metrics used to evaluate the quality of the provided explanations. Xpdeep provides two metrics available in Captum:
  - Infidelity (or Reliability): evaluates the reliability of the provided explanations by measuring how accurately the attributions reflect the output results.
  - Sensitivity (or Robustness): measures the stability or robustness of the explanation in response to minor perturbations in the surrounding area.
  Future Release
  Future Xpdeep releases will integrate more explanation quality metrics.
- metrics: a dictionary of customizable metrics for visualization alongside the explanations in the XpViz interface. This parameter is identical to the one used during the training stage, allowing you to add your own metrics as partials to assess, for example, accuracy or F1 score.
- statistics: a dictionary of customizable statistics for visualization alongside the explanations in the XpViz interface. See the API reference for the list of available statistics. Statistics must be instantiated, not partial.
Tip
Additional parameters are available, such as a seed or an explanation batch_size; see the API reference.
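As a minimal sketch of these optional parameters (the names windows_size, seed, and batch_size are taken from this page; combining them this way and the values themselves are purely illustrative, and windows_size assumes a time series dataset):
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity

# Illustrative values only: windows_size applies to time series data,
# seed fixes randomness, batch_size controls explanation batching.
ts_explainer = Explainer(
    description_representativeness=100,
    quality_metrics=[Sensitivity(), Infidelity()],
    windows_size=24,
    seed=42,
    batch_size=64,
)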
3. In Practice#
Let's build the explainer first.
from functools import partial

from torchmetrics import Accuracy, F1Score

from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, VarianceStat
from xpdeep.metrics.metric import DictMetrics

# Explanation quality metrics: robustness (Sensitivity) and reliability (Infidelity).
quality_metrics = [Sensitivity(), Infidelity()]

# Custom metrics, passed as partials (same as during training).
metrics = DictMetrics(
    accuracy=partial(Accuracy, task="multiclass", num_classes=3),
    f1_score=partial(F1Score, task="multiclass", num_classes=3),
)

# Custom statistics, instantiated (not partial).
statistics = DictStats(
    variance_target=VarianceStat(on="target", on_raw_data=True),
    variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)

explainer = Explainer(
    description_representativeness=100,
    quality_metrics=quality_metrics,
    metrics=metrics,
    statistics=statistics,
)
👀 Full file preview
from functools import partial
from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torchmetrics import Accuracy, F1Score
from train_model import trained_model
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, VarianceStat
from xpdeep.filtering.criteria import NumericalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metrics.metric import DictMetrics
quality_metrics = [Sensitivity(), Infidelity()]
metrics = DictMetrics(
accuracy=partial(Accuracy, task="multiclass", num_classes=3),
f1_score=partial(F1Score, task="multiclass", num_classes=3),
)
statistics = DictStats(
variance_target=VarianceStat(on="target", on_raw_data=True),
variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)
explainer = Explainer(
description_representativeness=100, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
)
# Model decision
model_decision = explainer.global_explain(
trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset
)
# Causal explanations
dataset_filter = Filter("my_filter", fitted_validation_dataset)
dataset_filter.add_criteria(NumericalCriterion(fitted_validation_dataset.fitted_schema["petal_length"], max_=3.0))
causal_explanations = explainer.local_explain(
trained_model=trained_model, train_set=fitted_train_dataset, dataset_filter=dataset_filter
)
4. Model Functioning Explanations#
Here, we provide the Explainer with the training set to be explained, along with any optional subsets for computing metrics and statistics on unseen data.
model_decision = explainer.global_explain(trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset)
Tip
You can print the explanation to display a clickable visualization link, which opens the XpViz platform in your browser.
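For example (a sketch; the exact link format depends on your XpViz deployment):
print(model_decision)  # displays the explanation, including its clickable XpViz link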
5. Inferences and their Causal Explanations#
As with the model functioning explanations, you can use the Explainer to generate predictions and their causal explanations.
Use a Dataset Filter#
The dataset_filter parameter specifies the samples for which to generate inferences and their causal explanations. It applies a filter to a FittedParquetDataset to select the desired samples, based either on columns (features) or on rows (samples).
Tip
If you want causal explanations on the whole dataset, simply instantiate a Filter without criteria. If you need a specific filter that is not implemented on the client side, filter the data yourself, then create a FittedParquetDataset from that subset.
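For instance, a criteria-free filter over the validation set (reusing fitted_validation_dataset as in the full file preview below) selects every sample:
from prepare_dataset import fitted_validation_dataset
from xpdeep.filtering.filter import Filter

# No criteria added: the filter keeps every sample in the dataset.
dataset_filter = Filter("my_filter", fitted_validation_dataset)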
In the following example, the filter is applied on features: a NumericalCriterion for the NumericalFeature "petal_length", and a CategoricalCriterion for the CategoricalFeature "flower_type". In addition, we filter on the first 5 rows.
from xpdeep.filtering.criteria import CategoricalCriterion, NumericalCriterion
from xpdeep.filtering.filter import Filter

# Row-based filtering: keep only the first 5 samples.
dataset_filter = Filter("my_filter", fitted_parquet_dataset=train_parquet_dataset_fitted, row_indexes=[0, 1, 2, 3, 4])

# Feature-based criteria: petal_length <= 2.0 and flower_type in {"Versicolor", "Setosa"}.
dataset_filter.add_criteria(
    NumericalCriterion(train_parquet_dataset_fitted.fitted_schema["petal_length"], max_=2.0),
    CategoricalCriterion(
        train_parquet_dataset_fitted.fitted_schema["flower_type"],
        categories=["Versicolor", "Setosa"],
    ),
)
👀 Full file preview
from functools import partial
from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torchmetrics import Accuracy, F1Score
from train_model import trained_model
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, VarianceStat
from xpdeep.filtering.criteria import NumericalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metrics.metric import DictMetrics
quality_metrics = [Sensitivity(), Infidelity()]
metrics = DictMetrics(
accuracy=partial(Accuracy, task="multiclass", num_classes=3),
f1_score=partial(F1Score, task="multiclass", num_classes=3),
)
statistics = DictStats(
variance_target=VarianceStat(on="target", on_raw_data=True),
variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)
explainer = Explainer(
description_representativeness=100, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
)
# Model decision
model_decision = explainer.global_explain(
trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset
)
# Causal explanations
dataset_filter = Filter("my_filter", fitted_validation_dataset)
dataset_filter.add_criteria(NumericalCriterion(fitted_validation_dataset.fitted_schema["petal_length"], max_=3.0))
causal_explanations = explainer.local_explain(
trained_model=trained_model, train_set=fitted_train_dataset, dataset_filter=dataset_filter
)
Generate Inferences and their Causal Explanations#
Once the subset for inference is selected using the filter, the quality metrics of the explanations are evaluated on the training set. After this step, you can generate the inferences and their causal explanations.
causal_explanations = explainer.local_explain(trained_model=trained_model, dataset_filter=dataset_filter, train_set=fitted_train_dataset)
👀 Full file preview
from functools import partial
from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torchmetrics import Accuracy, F1Score
from train_model import trained_model
from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, VarianceStat
from xpdeep.filtering.criteria import NumericalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metrics.metric import DictMetrics
quality_metrics = [Sensitivity(), Infidelity()]
metrics = DictMetrics(
accuracy=partial(Accuracy, task="multiclass", num_classes=3),
f1_score=partial(F1Score, task="multiclass", num_classes=3),
)
statistics = DictStats(
variance_target=VarianceStat(on="target", on_raw_data=True),
variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)
explainer = Explainer(
description_representativeness=100, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
)
# Model decision
model_decision = explainer.global_explain(
trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset
)
# Causal explanations
dataset_filter = Filter("my_filter", fitted_validation_dataset)
dataset_filter.add_criteria(NumericalCriterion(fitted_validation_dataset.fitted_schema["petal_length"], max_=3.0))
causal_explanations = explainer.local_explain(
trained_model=trained_model, train_set=fitted_train_dataset, dataset_filter=dataset_filter
)
Tip
You can print the explanation to display a clickable visualization link, which redirects you to the XpViz platform in your browser.
Future Release
The explanation process will be interruptible with Ctrl+C.