explainer

How to explain a trained model.

Classes:

| Name      | Description             |
| --------- | ----------------------- |
| Explainer | Explain an XpdeepModel. |

Explainer

Explain an XpdeepModel.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| description_representativeness | int | Governs the explanation quality: the greater, the better, but the slower to compute. | required |
| quality_metrics | list[QualityMetrics] | Quality metrics to compute, such as Sensitivity or Infidelity. | required |
| window_size | int \| None | DTW window parameter, as a proportion (%). | None |
| metrics | DictMetrics | Metrics to compute along with the explanation (F1 score, etc.). | DictMetrics() |
| statistics | DictStats | Statistics to compute along with the explanation (variance on targets, etc.). | DictStats() |
| batch_size | int \| None | The batch size to use during explanation. | None |
| explanation_seed | int \| None | The seed to use during explanation. | None |
| batch_generation_seed | int \| None | The seed to use during batch generation. | None |
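
A minimal construction sketch. The import path follows the source location shown below (src/xpdeep/explain/explainer.py); the values are hypothetical, and quality_metrics is left empty because the concrete QualityMetrics members are not documented on this page.

from xpdeep.explain.explainer import Explainer

explainer = Explainer(
    description_representativeness=1000,  # the greater, the better, but slower to compute
    quality_metrics=[],                   # e.g. Sensitivity / Infidelity members of QualityMetrics
    window_size=None,                     # optional DTW window, as a proportion (%)
    batch_size=64,                        # optional; defaults to None
    explanation_seed=42,                  # optional; fixes the explanation seed for reproducibility
)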

Methods:

| Name | Description |
| ---- | ----------- |
| __attrs_post_init__ | Validate config. |
| local_explain | Create a causal explanation from a trained model. |
| global_explain | Compute the model decision on a trained model. |

Attributes:

| Name | Type |
| ---- | ---- |
| description_representativeness | int |
| quality_metrics | list[QualityMetrics] |
| window_size | int \| None |
| metrics | DictMetrics |
| statistics | DictStats |
| batch_size | int \| None |
| explanation_seed | int \| None |
| batch_generation_seed | int \| None |

description_representativeness: int

quality_metrics: list[QualityMetrics]

window_size: int | None = None

metrics: DictMetrics = _attrs_field(default=DictMetrics())

statistics: DictStats = _attrs_field(default=DictStats())

batch_size: int | None = None

explanation_seed: int | None = None

batch_generation_seed: int | None = None

__attrs_post_init__() -> None

Validate config.

Source code in src/xpdeep/explain/explainer.py
def __attrs_post_init__(self) -> None:
    """Validate config."""
    if self.description_representativeness < 0:
        msg = (
            f"`description_representativeness` must be a positive integer but is "
            f"{self.description_representativeness}"
        )
        raise ValueError(msg)
    if self.window_size is not None and self.window_size < 0:
        msg = f"`window_size` must be a positive integer but is {self.window_size}"
        raise ValueError(msg)
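
Because attrs calls __attrs_post_init__ automatically after __init__, an invalid configuration is rejected at construction time. A small sketch, reusing the hypothetical explainer construction above:

try:
    Explainer(description_representativeness=-1, quality_metrics=[])
except ValueError as err:
    print(err)  # `description_representativeness` must be a positive integer but is -1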

local_explain(trained_model: TrainedModelArtifact, train_set: FittedParquetDataset, dataset_filter: Filter) -> ExplanationArtifact

Create a causal explanation from a trained model.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| trained_model | TrainedModelArtifact | A model trained via the trainer interface. | required |
| train_set | FittedParquetDataset | A dataset representing a train split. | required |
| dataset_filter | Filter | A filter applied to the dataset to select the samples to explain. | required |

Returns:

| Type | Description |
| ---- | ----------- |
| ExplanationArtifact | The causal explanation results, containing the result as JSON. |

Source code in src/xpdeep/explain/explainer.py
def local_explain(
    self,
    trained_model: TrainedModelArtifact,
    train_set: FittedParquetDataset,
    dataset_filter: Filter,
) -> ExplanationArtifact:
    """Create a causal explanation from trained model.

    Parameters
    ----------
    trained_model : TrainedModelArtifact
        A model trained via the trainer interface
    train_set : FittedParquetDataset
        A dataset representing a train split.
    dataset_filter : Filter
        A filter used to filter the dataset and get samples to explain.

    Returns
    -------
    ExplanationArtifact
        The causal explanation results, containing the result as json.
    """
    best_checkpoint = trained_model.best_checkpoint

    if best_checkpoint is None:
        msg = "Something went wrong. Best checkpoint not found."
        raise ApiError(msg)

    dataset_filter.apply()

    local_explain_pipeline_input = LocalExplainPipelineInput(
        type_="LOCAL_EXPLAIN",
        explanation_batch_config=self._build_batch_config(),
        explanation_config=self._build_explanation_config(),
        train_dataset_artifact_id=train_set.artifact_id(),
        inference_dataset=dataset_filter.to_model(),
        trained_model_checkpoint_id=best_checkpoint.id,
        metrics=self.metrics.to_model(),  # type: ignore[arg-type]
        statistics=self.statistics.to_model(),  # type: ignore[arg-type]
    )

    client_factory = ClientFactory.CURRENT.get()

    with client_factory() as client:
        local_explanation_pipeline = handle_api_validation_errors(
            launch_pipeline.sync(
                project_id=Project.CURRENT.get().model.id,
                client=client,
                body=local_explain_pipeline_input,
            ),
        )

    return explanation_computing.get_pipeline_result(local_explanation_pipeline.id)
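
A hedged usage sketch: how the trained model, the fitted train split, and the filter are obtained depends on the trainer, dataset, and filter interfaces, which are not shown on this page, so the inputs below are placeholders.

trained_model: TrainedModelArtifact = ...   # produced by the trainer interface
train_set: FittedParquetDataset = ...       # the fitted train split
samples_filter: Filter = ...                # selects the samples to explain

explanation = explainer.local_explain(
    trained_model=trained_model,
    train_set=train_set,
    dataset_filter=samples_filter,
)
# `explanation` is an ExplanationArtifact wrapping the causal explanation results.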

global_explain(trained_model: TrainedModelArtifact, train_set: FittedParquetDataset, test_set: FittedParquetDataset | None = None, validation_set: FittedParquetDataset | None = None) -> ExplanationArtifact

Compute the model decision on a trained model.

Parameters:

| Name | Type | Description | Default |
| ---- | ---- | ----------- | ------- |
| trained_model | TrainedModelArtifact | A model trained via the trainer interface. | required |
| train_set | FittedParquetDataset | A dataset representing a train split. | required |
| test_set | FittedParquetDataset \| None | A dataset representing a test split, used to optionally compute split statistics. | None |
| validation_set | FittedParquetDataset \| None | A dataset representing a validation split, used to optionally compute split statistics. | None |

Returns:

| Type | Description |
| ---- | ----------- |
| ExplanationArtifact | The model decision results, containing the result as JSON. |

Source code in src/xpdeep/explain/explainer.py
def global_explain(
    self,
    trained_model: TrainedModelArtifact,
    train_set: FittedParquetDataset,
    test_set: FittedParquetDataset | None = None,
    validation_set: FittedParquetDataset | None = None,
) -> ExplanationArtifact:
    """Compute model decision on a trained model.

    Parameters
    ----------
    trained_model : TrainedModelArtifact
        A model trained via the trainer interface.
    train_set : FittedParquetDataset
        A dataset representing a train split.
    test_set : FittedParquetDataset | None
        A dataset representing a test split, used to optionally compute split statistics.
    validation_set : FittedParquetDataset | None
        A dataset representing a validation split, used to optionally compute split statistics.

    Returns
    -------
    ExplanationArtifact
        The model decision results, containing the result as json.
    """
    best_checkpoint = trained_model.best_checkpoint

    if best_checkpoint is None:
        msg = "Something went wrong. Best checkpoint not found."
        raise ApiError(msg)

    global_explain_pipeline_input = GlobalExplainPipelineInput(
        type_="GLOBAL_EXPLAIN",
        explanation_batch_config=self._build_batch_config(),
        explanation_config=self._build_explanation_config(),
        train_dataset_artifact_id=train_set.artifact_id(),
        trained_model_checkpoint_id=best_checkpoint.id,
        metrics=self.metrics.to_model(),
        statistics=self.statistics.to_model(),
        test_dataset_artifact_id=test_set.artifact_id() if test_set is not None else None,
        validation_dataset_artifact_id=validation_set.artifact_id() if validation_set is not None else None,
    )

    client_factory = ClientFactory.CURRENT.get()

    with client_factory() as client:
        global_explanation_job = handle_api_validation_errors(
            launch_pipeline.sync(
                project_id=Project.CURRENT.get().model.id,
                client=client,
                body=global_explain_pipeline_input,
            ),
        )

    return explanation_computing.get_pipeline_result(global_explanation_job.id)
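
A hedged usage sketch along the same lines; the test and validation splits are optional and only used to compute per-split statistics.

decision = explainer.global_explain(
    trained_model=trained_model,
    train_set=train_set,
    test_set=None,              # optional FittedParquetDataset
    validation_set=None,        # optional FittedParquetDataset
)
# `decision` is an ExplanationArtifact wrapping the model decision results.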