model_parameters

Model hyperparameters.

Classes:

Name Description
ModelDecisionGraphParameters

Hyperparameters that influence the decision graph of an Xpdeep model.

ModelDecisionGraphParameters #

Hyperparameters that influence the decision graph of an Xpdeep model.

The model decision graph is characterized by a set of learnt decisions and can be adjusted to the use case to provide more interpretable explanations.

Parameters:

Name Type Description Default

graph_depth #

int | Literal[default]

This parameter defines the maximum depth of the decision graph. Increasing the value allows for more decision nodes, providing finer granularity in the decision-making process, but also adds to the complexity of the resulting graph. The depth should be set according to the desired balance between explanation detail and overall complexity. In a classification task, the default value is the logarithm of the number of classes, rounded up to the nearest integer. In a regression task, it is set to 3.

"default"

prune_step #

int | None

Pruning is performed every prune_step epochs; valid values range from 1 to the maximum number of epochs. To disable pruning, set this parameter to None; the target_homogeneity_pruning_threshold and population_pruning_threshold parameters are then ignored.

None

target_homogeneity_pruning_threshold #

float | Literal[default]

To obtain a decision graph with an optimal structure, nodes that are sufficiently homogeneous are converted into leaves (i.e., they are pruned): when the homogeneity of the target variable within a node exceeds this threshold, the node becomes a leaf. The threshold takes a value in the range [0, 1]; a value of 1 disables this criterion, and the default value is 0.9. In classification tasks, homogeneity is measured as the proportion of the majority class within a node, and it is advised to set the threshold within [max_prop, 1], where max_prop denotes the proportion of the majority class in the training dataset. In regression tasks, homogeneity is one minus the population-weighted ratio between the variance of the target variable Y within the node and its variance over the training set: $$ 1 - \frac{\vert node \vert}{\vert train \vert}\frac{Var(Y \mid node)}{Var(Y \mid train)} $$.

"default"
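As an illustration, the two homogeneity measures above can be sketched in plain Python. The use of population variance and this exact weighting are assumptions read off the formula, not confirmed library behaviour.

```python
import statistics
from collections import Counter


def classification_homogeneity(node_labels) -> float:
    """Proportion of the majority class within a node."""
    counts = Counter(node_labels)
    return max(counts.values()) / len(node_labels)


def regression_homogeneity(node_targets, train_targets) -> float:
    """1 - (|node|/|train|) * Var(Y|node)/Var(Y|train), per the formula above."""
    weight = len(node_targets) / len(train_targets)
    ratio = statistics.pvariance(node_targets) / statistics.pvariance(train_targets)
    return 1 - weight * ratio
```

A perfectly pure node scores 1 under both measures; a node covering the whole training set scores 0 under the regression measure.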

population_pruning_threshold #

float | None

To obtain a decision graph with an optimal structure, nodes containing too few individuals (i.e., insufficiently represented samples) are converted into leaves. This threshold, ranging in [0, 1], defines the minimum proportion of individuals required for a node to be retained; if the proportion falls below it, the node is pruned. For extremely unbalanced classes, it is recommended to set this threshold to a fraction of the minority class proportion. Setting the threshold to 0 disables this criterion. By default, its value is set to 0.01.

None
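For the unbalanced-class recommendation above, the minority-class proportion can be computed directly. The helper below is hypothetical, not part of the xpdeep API.

```python
from collections import Counter


def minority_class_proportion(labels) -> float:
    """Proportion of the least-represented class; a possible starting
    point for population_pruning_threshold on unbalanced datasets."""
    counts = Counter(labels)
    return min(counts.values()) / len(labels)
```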

internal_model_complexity #

int

An integer between 1 and 10 that controls the complexity of the explainable model. A higher value results in a more complex model capable of producing more robust explanations, though at the cost of slower training compared to lower values (e.g., 1).

required

feature_extraction_output_type #

FeatureExtractionOutputType

An enum describing the output structure of the feature extraction model, required to build the explainable model.

required

target_homogeneity_weight #

float

This parameter controls the homogeneity of the target variable within the graph nodes. In classification tasks, higher values increase class purity within nodes, while in regression tasks, they reduce the variance of the target variable. It is recommended to set this parameter within the range [0, 1]: excessively high values (particularly those exceeding 1) may degrade the model's performance. Setting the parameter to 0 disables the homogeneity criterion. By default, it is set to 0.1.

0.1

discrimination_weight #

float

This parameter controls the discriminative power of the decisions made at each node of the graph. A higher value increases the model's ability to distinguish between individuals directed to the left group and those directed to the right group, by enhancing the separation based on their input features. It is recommended to set this parameter within the range [0,1], as values greater than 1 may degrade the model's performance. A value of 0 disables this discriminative constraint. By default, the value is set to 0.01.

0.01

balancing_weight #

float

This parameter regulates the balance between the proportions of the left and right groups resulting from the decisions made at each node. It is recommended to keep this parameter within the range [0, 1]: excessively high values (particularly those above 1) may hinder the model's improvement. A value of 0 disables this regulation and may lead to the formation of groups that are highly unbalanced in size. By default, the value is set to 0.01.

0.01

frozen_model #

bool

This parameter determines how the explanations are learned in relation to the main model. If True, only the explanations are trained, on top of an already trained model; the model's parameters remain frozen to preserve its initial performance. If False, the explanations are trained jointly with the model during its training process. By default, the value is set to False.

False

Attributes:

Name Type Description
additional_properties dict[str, Any]

A dictionary holding any additional properties not covered by the declared parameters.

Methods:

Name Description
__attrs_post_init__

Validate config.

feature_extraction_output_type: FeatureExtractionOutputType #

__attrs_post_init__() -> None #

Validate config.

Source code in src/xpdeep/model/model_parameters.py
@no_type_check
def __attrs_post_init__(self) -> None:
    """Validate config."""
    self.feature_extraction_output_type = ModelDecisionGraphParametersFeatureExtractionOutputType(
        self.feature_extraction_output_type.name
    )

    if isinstance(self.graph_depth, int) and self.graph_depth <= 0:
        msg = f"`graph_depth` must be a positive integer but is {self.graph_depth}"
        raise ValueError(msg)
    if self.prune_step is not None and (not isinstance(self.prune_step, int) or self.prune_step <= 0):
        msg = f"`prune_step` must be a positive integer but is {self.prune_step}"
        raise ValueError(msg)
    if (
        self.target_homogeneity_pruning_threshold is not None
        and not 0 <= self.target_homogeneity_pruning_threshold <= 1
    ):
        msg = (
            f"`target_homogeneity_pruning_threshold` must be a float in [0, 1] but is "
            f"{self.target_homogeneity_pruning_threshold}"
        )
        raise ValueError(msg)
    if self.population_pruning_threshold is not None and not 0 <= self.population_pruning_threshold <= 1:
        msg = f"`population_pruning_threshold` must be a float in [0, 1] but is {self.population_pruning_threshold}"
        raise ValueError(msg)
    if self.internal_model_complexity not in {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}:
        msg = (
            f"`internal model complexity` must be an integer between 1 and 10 but is "
            f"{self.internal_model_complexity}"
        )
        raise ValueError(msg)

    if self.target_homogeneity_weight < 0:
        msg = (
            f"`target_homogeneity_weight` must be greater than or equal to 0 but is "
            f"{self.target_homogeneity_weight}"
        )
        raise ValueError(msg)

    if self.discrimination_weight < 0:
        msg = f"`discrimination_weight` must be greater than or equal to 0 but is {self.discrimination_weight}"
        raise ValueError(msg)

    if self.balancing_weight < 0:
        msg = f"`balancing_weight` must be greater than or equal to 0 but is {self.balancing_weight}"
        raise ValueError(msg)