
model_parameters

Model hyperparameters.

Classes:

Name Description
ModelDecisionGraphParameters

Hyperparameters that influence the decision graph of an Xpdeep model.

Attributes:

Name Type Description
logger

logger = logging.getLogger(__name__)

ModelDecisionGraphParameters

Hyperparameters that influence the decision graph of an Xpdeep model.

The model decision graph is characterized by a set of learnt decisions and can be adjusted to each use case to provide more interpretable explanations.
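To make the idea of a decision graph concrete, here is a toy, tree-shaped sketch in plain Python. It assumes binary decisions that send each sample to a left or right child (the `discrimination_weight` and `balancing_weight` descriptions refer to left and right groups). The `Node` class and the `route` and `depth` helpers are hypothetical and not part of the Xpdeep API; in Xpdeep the decisions are learnt during training, whereas here they are hand-written lambdas.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Sequence


@dataclass
class Node:
    """Toy decision-graph node (hypothetical, not an Xpdeep class)."""

    decision: Optional[Callable[[Sequence[float]], bool]] = None  # True -> go left
    left: Optional["Node"] = None
    right: Optional["Node"] = None
    label: str = ""  # only meaningful on leaves

    def is_leaf(self) -> bool:
        return self.decision is None


def route(node: Node, sample: Sequence[float]) -> str:
    """Apply the decisions from the root down and return the label of the reached leaf."""
    while not node.is_leaf():
        node = node.left if node.decision(sample) else node.right
    return node.label


def depth(node: Node) -> int:
    """Number of decisions on the longest root-to-leaf path."""
    if node.is_leaf():
        return 0
    return 1 + max(depth(node.left), depth(node.right))


# Two hand-written decisions stand in for learnt ones.
root = Node(
    decision=lambda x: x[0] < 0.5,
    left=Node(label="A"),
    right=Node(decision=lambda x: x[1] < 0.2, left=Node(label="B"), right=Node(label="C")),
)
print(route(root, [0.7, 0.1]))  # B
print(depth(root))  # 2
```

Here `depth` counts decisions along the longest path, which is presumably the quantity that `graph_depth` caps; a binary graph of depth d has at most 2**d leaves.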

Parameters:

Name Type Description Default

graph_depth

int | None

Set the maximum depth of the decision graph. A higher value allows for more decision nodes and finer granularity but also increases graph complexity. When set to None, the backend selects a default value: ceil(log2(n_classes)) for classification tasks and 3 for regression tasks.

None

prune_step

int | None

Trigger graph pruning every prune_step epochs. Valid values are in [1, max_epochs]. Set to None to prune only on the last epoch.

None

population_pruning_threshold

float

To obtain a decision graph with an optimal structure, nodes that contain too few individuals (samples) are converted into leaves. This threshold, in the range [0, 1], defines the minimum proportion of individuals a node must contain to be retained; if the proportion falls below it, the node is pruned. When classes are extremely unbalanced, set this threshold as a percentage of the minority class proportion, since otherwise nodes that mainly contain minority-class individuals may be pruned. Setting the threshold to 0 disables this criterion. The default is 0.01.

0.01

internal_model_complexity

int

An integer between 1 and 10 that controls the complexity of the explainable model. Higher values yield a more complex model that can produce more robust explanations, at the cost of slower training than lower values (e.g., 1).

1

feature_extraction_output_type

FeatureExtractionOutputType | Literal['IMAGE_MATRIX | IMAGE_TENSOR | MATRIX | TEMPORAL_MATRIX | VECTOR | VIDEO_TENSOR | DFINE_MATRIX']

Describe the output structure of the feature extraction model. This value is required to build the explainable model.

required

target_homogeneity_weight

float

Encourage homogeneity of the target variable within nodes. In classification tasks, higher values increase class purity within nodes; in regression tasks, they reduce target variance. Recommended values lie in [0, 1]. Setting the weight to 0 disables the homogeneity constraint. The default is 0.1.

0.1

discrimination_weight

float

Encourage discriminative splits at each node. Higher values enforce stronger separation between left and right groups based on input features. Recommended values lie in [0, 1]; set to 0 to disable. The default is 0.01.

0.01

balancing_weight

float

Encourage balanced proportions between left and right groups after each decision. Recommended values lie in [0, 1]; increasing the value can prevent the graph from collapsing into very few leaves. Set to 0 to disable. The default is 0.01.

0.01

frozen_model

bool

Select how explanations are learned relative to the main model. If True, only the explanations are trained on an already trained model and the model weights stay frozen. If False, explanations are trained jointly with the model. The default is False.

False
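The defaults and thresholds documented for `graph_depth`, `prune_step` and `population_pruning_threshold` can be sketched as plain Python helpers. `default_graph_depth`, `pruning_epochs` and `becomes_leaf` are hypothetical names, not Xpdeep functions, and they restate the documented rules rather than the backend's code. Two assumptions go beyond the text: pruning is taken to run on every epoch that is a multiple of `prune_step` (counting epochs from 1), and a node's proportion of individuals is taken relative to the whole training population.

```python
from typing import List, Optional


def default_graph_depth(task: str, n_classes: Optional[int] = None) -> int:
    """Depth documented for graph_depth=None (hypothetical helper)."""
    if task == "classification":
        if n_classes is None or n_classes < 2:
            raise ValueError("classification needs n_classes >= 2")
        # (n - 1).bit_length() equals ceil(log2(n)) for n >= 1, without floating point.
        return (n_classes - 1).bit_length()
    if task == "regression":
        return 3
    raise ValueError(f"unknown task: {task!r}")


def pruning_epochs(prune_step: Optional[int], max_epochs: int) -> List[int]:
    """Epochs (1-indexed) at which pruning would run, ASSUMING multiples of prune_step."""
    if prune_step is None:
        return [max_epochs]  # documented: prune only on the last epoch
    if not 1 <= prune_step <= max_epochs:
        raise ValueError("prune_step must be in [1, max_epochs]")
    return list(range(prune_step, max_epochs + 1, prune_step))


def becomes_leaf(node_count: int, total_count: int, threshold: float = 0.01) -> bool:
    """True if a node's share of individuals is below population_pruning_threshold.

    ASSUMPTION: the share is measured against the whole training population.
    """
    return node_count / total_count < threshold


print(default_graph_depth("classification", 10))  # 4, and 2**4 = 16 leaves >= 10 classes
print(pruning_epochs(5, 20))  # [5, 10, 15, 20]
print(pruning_epochs(None, 20))  # [20]
print(becomes_leaf(40, 10_000))  # True: 0.4% < 1%
print(becomes_leaf(40, 10_000, threshold=0.0))  # False: 0 disables the criterion
```

The integer form `(n_classes - 1).bit_length()` stands in for `ceil(log2(n_classes))` only to avoid floating-point rounding. Since a binary graph of depth d has at most 2**d leaves, this default is the smallest depth with room for one leaf per class, which is presumably the reasoning behind it.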

Attributes:

Name Type Description
target_homogeneity_pruning_threshold float | Literal['default']

Deprecated, to be removed later.

reset_optimizer_on_pruning bool, default True

Deprecated, to be removed later.

additional_properties str


Methods:

Name Description
__attrs_post_init__

Validate config.

feature_extraction_output_type: FeatureExtractionOutputType | Literal['IMAGE_MATRIX | IMAGE_TENSOR | MATRIX | TEMPORAL_MATRIX | VECTOR | VIDEO_TENSOR | DFINE_MATRIX']

graph_depth: int | None = None

prune_step: int | None = None

internal_model_complexity: int = 1

frozen_model: bool = False

target_homogeneity_pruning_threshold: float | Literal['default'] = field(init=False, metadata={'doc': False}, default=0.0)

reset_optimizer_on_pruning: bool = field(init=False, metadata={'doc': False}, default=True)

__attrs_post_init__() -> None

Validate config.

Source code in src/xpdeep/model/model_parameters.py
```python
@no_type_check
def __attrs_post_init__(self) -> None:  # noqa: C901
    """Validate config."""
    # Convert plain string values to FeatureExtractionOutputType; enum members are kept as-is.
    if not isinstance(self.feature_extraction_output_type, FeatureExtractionOutputType):
        self.feature_extraction_output_type = FeatureExtractionOutputType(self.feature_extraction_output_type)
    if isinstance(self.graph_depth, int) and self.graph_depth <= 0:
        msg = f"`graph_depth` must be a positive integer but is {self.graph_depth}"
        raise ValueError(msg)
    # Check the type before comparing, so a non-integer raises ValueError rather than TypeError.
    if self.prune_step is not None and (not isinstance(self.prune_step, int) or self.prune_step <= 0):
        msg = f"`prune_step` must be a positive integer but is {self.prune_step}"
        raise ValueError(msg)
    if self.population_pruning_threshold is not None and not 0 <= self.population_pruning_threshold <= 1:
        msg = f"`population_pruning_threshold` must be a float in [0, 1] but is {self.population_pruning_threshold}"
        raise ValueError(msg)
    if self.internal_model_complexity not in {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}:
        msg = (
            f"`internal_model_complexity` must be an integer between 1 and 10 but is "
            f"{self.internal_model_complexity}"
        )
        raise ValueError(msg)

    if self.target_homogeneity_weight < 0:
        msg = (
            f"`target_homogeneity_weight` must be greater than or equal to 0 but is "
            f"{self.target_homogeneity_weight}"
        )
        raise ValueError(msg)

    if self.discrimination_weight < 0:
        msg = f"`discrimination_weight` must be greater than or equal to 0 but is {self.discrimination_weight}"
        raise ValueError(msg)

    if self.balancing_weight < 0:
        msg = f"`balancing_weight` must be greater than or equal to 0 but is {self.balancing_weight}"
        raise ValueError(msg)
    # Deprecated fields: warn only when a non-default value is set.
    if self.target_homogeneity_pruning_threshold != 0.0:
        logger.warning(
            "`target_homogeneity_pruning_threshold` is deprecated and will be removed in the next release."
        )
    if not self.reset_optimizer_on_pruning:
        logger.warning("`reset_optimizer_on_pruning` is deprecated and will be removed in the next release.")
```
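The validation contract can be exercised without the Xpdeep backend through a stand-in. The sketch below is a hypothetical `dataclasses` mirror of a subset of the checks in `__attrs_post_init__`, not the library's class: `GraphParamsSketch` and `error_of` are illustrative names, and the enum conversion and deprecation warnings are left out. It checks the type of `prune_step` before comparing it, and it builds each weight message from the field name so that every error names the parameter that actually failed.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class GraphParamsSketch:
    """Hypothetical stand-in mirroring part of ModelDecisionGraphParameters validation."""

    graph_depth: Optional[int] = None
    prune_step: Optional[int] = None
    population_pruning_threshold: float = 0.01
    internal_model_complexity: int = 1
    target_homogeneity_weight: float = 0.1
    discrimination_weight: float = 0.01
    balancing_weight: float = 0.01

    def __post_init__(self) -> None:
        if isinstance(self.graph_depth, int) and self.graph_depth <= 0:
            raise ValueError(f"`graph_depth` must be a positive integer but is {self.graph_depth}")
        # Type check first: comparing a str such as "5" with 0 would raise TypeError.
        if self.prune_step is not None and (not isinstance(self.prune_step, int) or self.prune_step <= 0):
            raise ValueError(f"`prune_step` must be a positive integer but is {self.prune_step}")
        if not 0 <= self.population_pruning_threshold <= 1:
            raise ValueError(
                "`population_pruning_threshold` must be a float in [0, 1] "
                f"but is {self.population_pruning_threshold}"
            )
        if self.internal_model_complexity not in range(1, 11):
            raise ValueError(
                "`internal_model_complexity` must be an integer between 1 and 10 "
                f"but is {self.internal_model_complexity}"
            )
        # Build each message from the field name so it always names the right parameter.
        for name in ("target_homogeneity_weight", "discrimination_weight", "balancing_weight"):
            value = getattr(self, name)
            if value < 0:
                raise ValueError(f"`{name}` must be greater than or equal to 0 but is {value}")


def error_of(**kwargs: Any) -> Optional[str]:
    """Return the validation message for the given arguments, or None if they are valid."""
    try:
        GraphParamsSketch(**kwargs)
    except ValueError as exc:
        return str(exc)
    return None


print(error_of(balancing_weight=-0.5))  # `balancing_weight` must be greater than or equal to 0 but is -0.5
print(error_of(prune_step="5"))  # `prune_step` must be a positive integer but is 5
print(error_of(graph_depth=4, prune_step=2))  # None
```

Generating the weight messages in a loop is a choice made for this sketch; it removes the copy-and-paste risk that comes with three hand-written messages.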