Skip to content

filter

Module defining the filter object.

Classes:

Name Description
Filter

Filter class, used to filter a dataset for causal explanation on a subset.

Filter(name: str, fitted_parquet_dataset: FittedParquetDataset, criteria: Sequence[NumericalCriterion | CategoricalCriterion | MultivariateCriterion | TimeseriesBoxCriterion] = (), row_indexes: Sequence[int] = (), *, min_index: int | None = None, max_index: int | None = None) #

Filter class, used to filter a dataset for causal explanation on a subset.

"dataset", "min_index", "max_index" attributes should never be updated once the instance is created. It is still safe to update other attributes.

Initialize the filter object.

Parameters:

Name Type Description Default

name #

str

The filter name.

required

fitted_parquet_dataset #

FittedParquetDataset

The dataset the filter will be applied to. Do not update after object initialization.

required

criteria #

Sequence[NumericalCriterion | CategoricalCriterion | MultivariateCriterion | TimeseriesBoxCriterion]

An optional list of filter criteria to filter by feature.

()

row_indexes #

Sequence[int]

An optional list of row indexes to filter by sample.

()

min_index #

int | None

Filter rows with lower indexes than min_index. Do not update after Object initialization.

None

max_index #

int | None

Filter rows with greater indexes than max_index. Do not update after Object initialization.

None

Methods:

Name Description
__setattr__

Set attribute.

add_criteria

Add many criteria.

save

Save the Filter remotely.

load_all

List all filters of the current project.

get_by_id

Get Filter by its ID.

get_by_name

Get Filter by its name.

delete

Delete the current object remotely.

__len__

Get filter's result size.

Attributes:

Name Type Description
name
criteria
row_indexes
dataset
min_index
max_index
id str

Get id.

Source code in src/xpdeep/filtering/filter.py
def __init__(  # noqa:PLR0913
    self,
    name: str,
    fitted_parquet_dataset: FittedParquetDataset,
    criteria: Sequence[
        NumericalCriterion | CategoricalCriterion | MultivariateCriterion | TimeseriesBoxCriterion
    ] = (),
    row_indexes: Sequence[int] = (),
    *,
    min_index: int | None = None,
    max_index: int | None = None,
) -> None:
    """
    Initialize the filter object.

    Parameters
    ----------
    name : str
        The filter name.
    fitted_parquet_dataset : FittedParquetDataset
        The dataset the filter will be applied to. Stored as the ``dataset`` attribute. Do not update after
        object initialization.
    criteria : Sequence of criteria, default ()
        An optional list of filter criteria to filter by feature. Each item is a ``NumericalCriterion``,
        ``CategoricalCriterion``, ``MultivariateCriterion`` or ``TimeseriesBoxCriterion``. The sequence is
        copied into a new list.
    row_indexes : Sequence[int], default ()
        An optional list of row indexes to filter by sample. Stored as given (not copied).
    min_index : int | None, default None
        Filter rows with lower indexes than min_index. Do not update after object initialization.
    max_index : int | None, default None
        Filter rows with greater indexes than max_index. Do not update after object initialization.
    """
    # Every assignment below is routed through the overridden `__setattr__`. That method only attempts a
    # remote update when both `name` and the assigned attribute already exist on the instance, which its
    # own comment describes as "not the init process"; these first assignments are therefore expected to
    # be purely local.
    self.name = name
    # Copy into a list so the stored criteria are independent of the caller's sequence and can be
    # extended in place (see `add_criteria`, which uses `+=`).
    self.criteria = list(criteria)
    self.row_indexes = row_indexes
    # The constructor argument is `fitted_parquet_dataset`, but the attribute is named `dataset`.
    self.dataset = fitted_parquet_dataset
    self.min_index = min_index
    self.max_index = max_index

name = name #

criteria = list(criteria) #

row_indexes = row_indexes #

dataset = fitted_parquet_dataset #

min_index = min_index #

max_index = max_index #

id: str #

Get id.

__setattr__(attr: str, value: object) -> None #

Set attribute.

Source code in src/xpdeep/filtering/filter.py
@initialized_client_verification
@initialized_project_verification
def __setattr__(self, attr: str, value: object) -> None:
    """
    Set an attribute locally, pushing the change to the remote filter first when one exists.

    When both ``name`` and ``attr`` already exist on the instance, a client is opened and ``self.id`` is
    read. If that succeeds:

    * assigning ``dataset``, ``min_index`` or ``max_index`` raises ``AttributeError``;
    * any other attribute triggers ``update_filter.sync`` with a body built by ``self._to_update``, which
      receives the new value only for ``name``, ``row_indexes`` or ``criteria`` (``None`` otherwise).

    If a ``NotSavedError`` is raised inside that block (presumably by ``self.id`` when the filter has not
    been saved yet), it is swallowed and the attribute is only set locally.

    Parameters
    ----------
    attr : str
        Name of the attribute being assigned.
    value : object
        New value for the attribute.

    Raises
    ------
    AttributeError
        If ``attr`` is ``dataset``, ``min_index`` or ``max_index`` and ``self.id`` could be resolved.
    """
    # If the instance does not have attribute `name` and the current updated attribute, it means that it's
    # an init process and the current object cannot be associated with the remote one without knowing its name
    if hasattr(self, "name") and hasattr(self, attr):
        try:  # Update the filter in database only if it exists remotely
            with ClientFactory.CURRENT.get()() as client:
                # Read the id before anything else: a NotSavedError raised here jumps straight to the
                # `except` below, skipping both the immutability check and the remote update.
                filter_id = self.id

                # NOTE(review): because this check sits after the `self.id` lookup, it is not reached
                # when NotSavedError is raised above, so these attributes can still be reassigned on a
                # filter that has no remote counterpart. Confirm this is intended.
                if attr in {"dataset", "min_index", "max_index"}:
                    message = (
                        f"Updating {attr} attribute is not allowed. Consider creating a new filter "
                        "instance with the new dataset value to achieve this"
                    )
                    raise AttributeError(message)

                # The remote update happens BEFORE the local assignment at the end of the method, so an
                # exception raised here leaves the local attribute unchanged.
                update_filter.sync(
                    project_id=Project.CURRENT.get().model.id,
                    filter_id=filter_id,
                    client=client,
                    body=self._to_update(
                        new_name=value if attr == "name" else None,  # type: ignore[arg-type]
                        new_row_indexes=value if attr == "row_indexes" else None,  # type: ignore[arg-type]
                        new_criteria=value if attr == "criteria" else None,  # type: ignore[arg-type]
                    ),
                )
        except NotSavedError:
            # Nothing to synchronise remotely: fall through to the plain local assignment.
            pass

    # Bypass this override to actually store the value on the instance.
    object.__setattr__(self, attr, value)

add_criteria(*args: NumericalCriterion | CategoricalCriterion | MultivariateCriterion | TimeseriesBoxCriterion) -> None #

Add many criteria.

Source code in src/xpdeep/filtering/filter.py
def add_criteria(
    self, *args: NumericalCriterion | CategoricalCriterion | MultivariateCriterion | TimeseriesBoxCriterion
) -> None:
    """
    Append the given criteria to the filter's criteria.

    Parameters
    ----------
    *args : NumericalCriterion | CategoricalCriterion | MultivariateCriterion | TimeseriesBoxCriterion
        The criteria to add, in order.
    """
    # Spelled-out form of an augmented assignment: extend the current container in place, then store it
    # back on the instance so the attribute assignment hook still runs.
    current = self.criteria
    current += list(args)
    self.criteria = current

save(*, force: bool = False) -> Filter #

Save the Filter remotely.

Source code in src/xpdeep/filtering/filter.py
@initialized_client_verification
@initialized_project_verification
def save(self, *, force: bool = False) -> "Filter":
    """
    Save the Filter remotely.

    Parameters
    ----------
    force : bool, default False
        What to do when the insertion fails with ``UnexpectedStatus`` (handled here as "a filter with the
        same name already exists"). If ``False``, raise ``DuplicatedRemoteObjectError``. If ``True``,
        append the current timestamp to ``name`` and try the insertion once more.

    Returns
    -------
    Filter
        The current instance, to allow call chaining.

    Raises
    ------
    DuplicatedRemoteObjectError
        If the insertion fails with ``UnexpectedStatus`` and ``force`` is ``False``.
    """

    def _insert() -> None:
        """Insert the filter in the current project, using a client that is closed afterwards."""
        # Use the client as a context manager, like the other remote calls of this class, so that it
        # is released even when the insertion raises.
        with ClientFactory.CURRENT.get()() as client:
            handle_api_validation_errors(
                insert_filter.sync(
                    project_id=Project.CURRENT.get().model.id,
                    client=client,
                    body=self._to_insert(),
                ),
            )

    try:  # Try to insert the filter in DB
        _insert()
    except UnexpectedStatus as err:
        # Filter with the same name already exists in DB.
        if not force:
            message = (
                f"The Filter: {self.name} already exists in database. Update `name` or use `force=True` in "
                f"order to create new remote object with a different name."
            )
            raise DuplicatedRemoteObjectError(message) from err

        # Make the name unique by suffixing it with the current local time, then retry once. A failure of
        # this second attempt is propagated as is.
        self.name += f"_{datetime.now()}"  # noqa:DTZ005
        _insert()

    return self

load_all() -> list[Filter] #

List all filters of the current project.

Source code in src/xpdeep/filtering/filter.py
@classmethod
@initialized_client_verification
@initialized_project_verification
def load_all(cls) -> list["Filter"]:
    """
    List all filters of the current project.

    Returns
    -------
    list[Filter]
        One ``Filter`` per item returned by ``cls._load()``, each built with ``cls._from_select_one``.
    """
    return list(map(cls._from_select_one, cls._load()))

get_by_id(filter_id: str) -> Filter #

Get Filter by its ID.

Parameters:

Name Type Description Default
filter_id #
str

The ID of the Filter to retrieve.

required
Source code in src/xpdeep/filtering/filter.py
@classmethod
@initialized_client_verification
@initialized_project_verification
def get_by_id(cls, filter_id: str) -> "Filter":
    """
    Get Filter by its ID.

    Parameters
    ----------
    filter_id : str
        The ID of the Filter to retrieve.

    Returns
    -------
    Filter
        The filter built from the first item loaded for ``filter_id``.

    Raises
    ------
    NotSavedError
        If the lookup yields no item.
    """
    try:
        matches = iter(cls._load(filter_id=filter_id))
        # Only the first match is used; an exhausted iterator means nothing was found.
        first_match = next(matches)
        return cls._from_select_one(first_match)
    except StopIteration as not_found:
        raise NotSavedError(f"No Filter found remotely for the filter ID: {filter_id}") from not_found

get_by_name(filter_name: str) -> Filter #

Get Filter by its name.

Parameters:

Name Type Description Default
filter_name #
str

The name of the Filter to retrieve.

required
Source code in src/xpdeep/filtering/filter.py
@classmethod
@initialized_client_verification
@initialized_project_verification
def get_by_name(cls, filter_name: str) -> "Filter":
    """
    Get Filter by its name.

    Parameters
    ----------
    filter_name : str
        The name of the Filter to retrieve.

    Returns
    -------
    Filter
        The filter built from the first item loaded for ``filter_name``.

    Raises
    ------
    NotSavedError
        If the lookup yields no item.
    """
    try:
        matches = iter(cls._load(filter_name=filter_name))
        # Only the first match is used; an exhausted iterator means nothing was found.
        first_match = next(matches)
        return cls._from_select_one(first_match)
    except StopIteration as not_found:
        raise NotSavedError(f"No Filter found remotely for the filter name: {filter_name}") from not_found

delete() -> None #

Delete the current object remotely.

Source code in src/xpdeep/filtering/filter.py
def delete(self) -> None:
    """
    Delete the current object remotely.

    If a ``NotSavedError`` is raised while resolving the project or filter id or while issuing the
    deletion, a ``UserWarning`` is emitted instead of propagating the error.
    """
    with ClientFactory.CURRENT.get()() as client:
        try:
            project_id = Project.CURRENT.get().model.id
            remote_id = self.id
            delete_one_filter.sync(project_id, remote_id, client=client)
        except NotSavedError:
            # Best effort: report the missing remote object without failing the caller.
            warnings.warn("The current object does not exist remotely.", UserWarning, stacklevel=2)

__len__() -> int #

Get filter's result size.

Ignores row_indexes.

Source code in src/xpdeep/filtering/filter.py
@initialized_client_verification
@initialized_project_verification
def __len__(self) -> int:
    """
    Get filter's result size.

    Ignores row_indexes.

    Returns
    -------
    int
        The size returned by the remote service for the filter's dataset restricted by this filter.

    Raises
    ------
    NotSavedError
        If reading the dataset id or the filter id raises ``NotSavedError``, i.e. the dataset or the filter
        has to be saved first.
    """
    # Both ids are resolved up front so that each missing prerequisite gets its own actionable message
    # before any client is opened.
    try:
        dataset_id = self.dataset.id
    except NotSavedError as err:
        message = (
            "The Fitted Parquet Dataset associated with this filter was not found remotely. First save it "
            "using `FittedParquetDataset.save` method"
        )
        raise NotSavedError(message) from err

    try:
        filter_id = self.id
    except NotSavedError as err:
        message = "This filter is not applied yet. Call `save` method before computing filter's results size."
        raise NotSavedError(message) from err

    with ClientFactory.CURRENT.get()() as client:
        return handle_api_validation_errors(
            get_parquet_dataset_artifact_size.sync(
                project_id=Project.CURRENT.get().model.id,
                parquet_dataset_artifact_id=dataset_id,
                client=client,
                body=FilterIdRequestBody(
                    filter_id=filter_id,
                ),
            ),
        )