How to Compute Explanations#

With an explainable dataset FittedParquetDataset and a trained model artifact TrainedModelArtifact, Xpdeep generates two types of explanations through the interactive XpViz interface:

  • Model functioning explanations, represented by a graph titled 'Model Decision Graph', which illustrates how the learned model works: the set of decisions made to accomplish the task, the key features involved and their importance, the reliability and robustness of the explanations, model performance, and various metrics and statistics related to explanation quality.
  • Inference causal explanations, represented by a graph titled 'Inference Graph', which details the full set of decisions involved in generating a prediction for an individual or group of individuals. They highlight the key causal features and their importance, the explanations' reliability and robustness, and various metrics and statistics related to the resulting prediction.

For more information on Xpdeep explanations, refer to the explanations section.

Tip

A video tutorial is available to guide you through the XpViz interface and introduce all the explanations.

1. Build the Explainer#

Given a FittedParquetDataset and a trained XpdeepModel, the Explainer interface allows you to:

  • Generate the model explanations into the Model Decision Graph.
  • Generate predictions with their explanations into the Inference Graph.

Future Release

Generate predictions without any explanations, as with a standard deep model.

2. Get Explanations with the Explainer#

Given the TrainedModelArtifact, obtained once the model is trained, let's detail the Explainer characteristics and required input parameters:

  • description_representativeness: each node/leaf of the graph corresponds to a distribution of individuals. To provide explanations, metrics, or statistics for each node/leaf, a representative subset is selected. This parameter determines the size of that representative subset. The larger the subset, the more precise the explanations, statistics, and metrics will be, but the computational complexity will also increase.

  • windows_size: an optional parameter specific to time series data. It allows you to specify the window size for warping and synchronizing temporal data, for example, when using the Dynamic Time Warping (DTW) metric.

  • quality_metrics: this parameter specifies the metrics used to evaluate the quality of the explanations provided. Xpdeep provides two metrics available in Captum:

    • Infidelity (or Reliability): evaluates the reliability of the provided explanations by measuring how accurately the attributions reflect the model's outputs.

    • Sensitivity (or Robustness): measures the stability, or robustness, of the explanation under minor perturbations of the input.

    Future Release

    Future Xpdeep releases will integrate more explanation quality metrics.

  • metrics: a dictionary of customizable metrics for visualization alongside the explanations in the XpViz interface. This parameter is identical to the one used during the training stage, allowing you to add your own metrics as partials to assess, for example, accuracy or F1 score.

  • statistics: a dictionary of customizable statistics for visualization alongside the explanations in the XpViz interface. See the API reference for a list of available statistics. Statistics must be instantiated, not passed as partials.

Tip

Additional parameters are available, such as a seed or an explanation batch_size; see the API reference. A sketch combining these options is shown below.
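
For time series data, these parameters can be combined when building the explainer. Below is a minimal sketch, assuming the windows_size parameter described above and the seed and batch_size keyword names mentioned in this Tip; check the API reference for the exact signatures.

# Hypothetical configuration for time series data.
# `windows_size`, `seed` and `batch_size` are assumed from the
# parameter list and Tip above; verify their names in the API reference.
explainer = Explainer(
    description_representativeness=200,  # larger subset: finer explanations, higher compute cost
    quality_metrics=quality_metrics,
    metrics=metrics,
    statistics=statistics,
    windows_size=16,  # window used to warp/synchronize temporal data (e.g. with DTW)
    seed=42,  # assumed seed parameter for reproducibility
    batch_size=64,  # assumed explanation batch size
)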

3. In Practice#

Let's build the explainer first.

from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Sensitivity, Infidelity
from torchmetrics import Accuracy, F1Score
from functools import partial
from xpdeep.metrics.metric import DictMetrics
from xpdeep.explain.statistic import DictStats, VarianceStat

quality_metrics = [Sensitivity(), Infidelity()]

metrics = DictMetrics(
    accuracy=partial(Accuracy, task="multiclass", num_classes=3),
    f1_score=partial(F1Score, task="multiclass", num_classes=3),
)

statistics = DictStats(
    variance_target=VarianceStat(on="target", on_raw_data=True),
    variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)

explainer = Explainer(
    description_representativeness=100,
    quality_metrics=quality_metrics,
    metrics=metrics,
    statistics=statistics
)
👀 Full file preview
from functools import partial

from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torchmetrics import Accuracy, F1Score
from train_model import trained_model

from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, VarianceStat
from xpdeep.filtering.criteria import NumericalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metrics.metric import DictMetrics

quality_metrics = [Sensitivity(), Infidelity()]

metrics = DictMetrics(
    accuracy=partial(Accuracy, task="multiclass", num_classes=3),
    f1_score=partial(F1Score, task="multiclass", num_classes=3),
)

statistics = DictStats(
    variance_target=VarianceStat(on="target", on_raw_data=True),
    variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)

explainer = Explainer(
    description_representativeness=100, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
)

# Model decision
model_decision = explainer.global_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset
)

# Causal explanations
dataset_filter = Filter("my_filter", fitted_validation_dataset)
dataset_filter.add_criteria(NumericalCriterion(fitted_validation_dataset.fitted_schema["petal_length"], max_=3.0))

causal_explanations = explainer.local_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, dataset_filter=dataset_filter
)

4. Model Functioning Explanations#

Here, we provide the Explainer with the training set to be explained, along with any optional subsets for computing metrics and statistics on unseen data.

model_decision = explainer.global_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset
)

Tip

You can print the explanation to display a clickable visualization link. It will take you to the XpViz platform in your browser.

print(model_decision)

5. Inference and their Causal Explanations#

Similarly to the model functioning explanations, you can use the Explainer to generate predictions and their causal explanations.

Use a Dataset Filter#

The dataset_filter parameter specifies the samples for which to generate inferences and their causal explanations. It applies a filter to a FittedParquetDataset to select the desired samples, which can be based on columns (features) or on rows (samples).

Tip

If you want causal explanations on the whole dataset, you only need to instantiate a Filter without criteria.

In addition, if you need a specific filter that is not implemented on the client, filter the data on your side, then create a FittedParquetDataset from that subset.
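
For instance, a criteria-free filter covering the whole validation set could look like the following sketch, which reuses the Filter constructor from the full file preview; the filter name "all_samples" is illustrative, and fitted_validation_dataset is assumed to come from prepare_dataset as in the preview.

from xpdeep.filtering.filter import Filter

# No criteria added: the filter selects every sample in the dataset.
dataset_filter = Filter("all_samples", fitted_validation_dataset)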

In the following example, the filter is applied to features. We use a NumericalCriterion for the NumericalFeature "petal_length" and a CategoricalCriterion for the CategoricalFeature "flower_type". In addition, we restrict the filter to the first five rows.

from xpdeep.filtering.filter import Filter
from xpdeep.filtering.criteria import CategoricalCriterion, NumericalCriterion

dataset_filter = Filter("my_filter", fitted_parquet_dataset=train_parquet_dataset_fitted, row_indexes=[0, 1, 2, 3, 4])
dataset_filter.add_criteria(
    NumericalCriterion(train_parquet_dataset_fitted.fitted_schema["petal_length"], max_=2.0),
    CategoricalCriterion(
        train_parquet_dataset_fitted.fitted_schema["flower_type"], 
        categories=["Versicolor", "Setosa"]
    ),
)
👀 Full file preview
from functools import partial

from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torchmetrics import Accuracy, F1Score
from train_model import trained_model

from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, VarianceStat
from xpdeep.filtering.criteria import NumericalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metrics.metric import DictMetrics

quality_metrics = [Sensitivity(), Infidelity()]

metrics = DictMetrics(
    accuracy=partial(Accuracy, task="multiclass", num_classes=3),
    f1_score=partial(F1Score, task="multiclass", num_classes=3),
)

statistics = DictStats(
    variance_target=VarianceStat(on="target", on_raw_data=True),
    variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)

explainer = Explainer(
    description_representativeness=100, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
)

# Model decision
model_decision = explainer.global_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset
)

# Causal explanations
dataset_filter = Filter("my_filter", fitted_validation_dataset)
dataset_filter.add_criteria(NumericalCriterion(fitted_validation_dataset.fitted_schema["petal_length"], max_=3.0))

causal_explanations = explainer.local_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, dataset_filter=dataset_filter
)

Generate Inferences and their Causal Explanations#

Once the subset for inference is selected (using the filters), the quality metrics of the explanations are evaluated based on the training set. After this step, you can proceed to generate the inferences and their causal explanations.

causal_explanations = explainer.local_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, dataset_filter=dataset_filter
)
👀 Full file preview
from functools import partial

from prepare_dataset import fitted_train_dataset, fitted_validation_dataset
from torchmetrics import Accuracy, F1Score
from train_model import trained_model

from xpdeep.explain.explainer import Explainer
from xpdeep.explain.quality_metrics import Infidelity, Sensitivity
from xpdeep.explain.statistic import DictStats, VarianceStat
from xpdeep.filtering.criteria import NumericalCriterion
from xpdeep.filtering.filter import Filter
from xpdeep.metrics.metric import DictMetrics

quality_metrics = [Sensitivity(), Infidelity()]

metrics = DictMetrics(
    accuracy=partial(Accuracy, task="multiclass", num_classes=3),
    f1_score=partial(F1Score, task="multiclass", num_classes=3),
)

statistics = DictStats(
    variance_target=VarianceStat(on="target", on_raw_data=True),
    variance_prediction=VarianceStat(on="prediction", on_raw_data=True),
)

explainer = Explainer(
    description_representativeness=100, quality_metrics=quality_metrics, metrics=metrics, statistics=statistics
)

# Model decision
model_decision = explainer.global_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, validation_set=fitted_validation_dataset
)

# Causal explanations
dataset_filter = Filter("my_filter", fitted_validation_dataset)
dataset_filter.add_criteria(NumericalCriterion(fitted_validation_dataset.fitted_schema["petal_length"], max_=3.0))

causal_explanations = explainer.local_explain(
    trained_model=trained_model, train_set=fitted_train_dataset, dataset_filter=dataset_filter
)

Tip

You can print the explanation to display a clickable visualization link. It will redirect you to the XpViz platform in your browser.

print(causal_explanations)

Future Release

The explanation process will be interruptible with Ctrl+C.