Skip to content

utils_serialization

Utility functions for serialization. Deserialization is not implemented on the client side.

Type Aliases:

Name Description
AllAugmentationTransform

Classes:

Name Description
InvalidTransformError

Raised when the transform cannot be converted to the database model.

Functions:

Name Description
make_jsonable

Convert transform kwargs into JSON-serializable data.

to_insert_transform_augmentation

Convert a torchvision transform (one of TransformAugmentationTypes) to its insert value.

to_insert_augmentation

Convert an augmentation to its insert model value.

Attributes:

Name Type Description
TransformAugmentationTypes

# Non-container transforms that `to_insert_augmentation` hands off to
# `to_insert_transform_augmentation` (tested with `isinstance` against this tuple).
TransformAugmentationTypes = (CenterCrop, ColorJitter, ConvertImageDtype, ElasticTransform, FiveCrop, GaussianBlur, Grayscale, InterpolationMode, Normalize, Pad, PILToTensor, RandomAdjustSharpness, RandomAffine, RandomAutocontrast, RandomCrop, RandomErasing, RandomEqualize, RandomGrayscale, RandomHorizontalFlip, RandomInvert, RandomPerspective, RandomPosterize, RandomResizedCrop, RandomRotation, RandomSolarize, RandomVerticalFlip, Resize, TenCrop, ToPILImage, ToTensor) #

# Every transform accepted by `to_insert_augmentation`: the members of
# TransformAugmentationTypes plus the container transforms (Compose, RandomApply,
# RandomChoice, RandomOrder) and LinearTransformation.
AllAugmentationTransform = CenterCrop | ColorJitter | Compose | ConvertImageDtype | ElasticTransform | FiveCrop | GaussianBlur | Grayscale | InterpolationMode | LinearTransformation | Normalize | Pad | PILToTensor | RandomAdjustSharpness | RandomAffine | RandomApply | RandomAutocontrast | RandomChoice | RandomCrop | RandomErasing | RandomEqualize | RandomGrayscale | RandomHorizontalFlip | RandomInvert | RandomOrder | RandomPerspective | RandomPosterize | RandomResizedCrop | RandomRotation | RandomSolarize | RandomVerticalFlip | Resize | TenCrop | ToPILImage | ToTensor #

InvalidTransformError #

Raised when the transform cannot be converted to the database model.

make_jsonable(kwargs: dict[str, object]) -> TransformAugmentationValueKwargs #

Convert transform kwargs into JSON-serializable data.

Source code in src/xpdeep/dataset/feature/augmentation/utils_serialization.py
def _to_jsonable_value(value: object) -> object:
    """Recursively convert one transform kwarg value into a JSON-compatible form.

    - tuples and lists become lists, each element converted recursively;
    - ``InterpolationMode`` members become their ``name``;
    - ``torch.dtype`` objects become their string form (e.g. ``"torch.float32"``);
    - ``torch.Tensor`` objects become (nested) Python values via ``tolist()``;
    - any other value is returned unchanged.
    """
    if isinstance(value, (tuple, list)):
        # Recurse so nested sequences (e.g. per-channel fill values) are converted too.
        return [_to_jsonable_value(item) for item in value]
    if isinstance(value, InterpolationMode):
        return value.name
    if isinstance(value, torch.dtype):
        return str(value)
    if isinstance(value, torch.Tensor):
        # Some transforms keep user-supplied tensors as-is (e.g. Normalize mean/std).
        return value.tolist()
    return value


def make_jsonable(kwargs: dict[str, object]) -> TransformAugmentationValueKwargs:
    """Convert transform kwargs into JSON-compatible data.

    The input mapping is left untouched; a converted copy is handed to
    ``TransformAugmentationValueKwargs.from_dict``.

    Args:
        kwargs: Keyword arguments read from a torchvision transform.

    Returns:
        The converted kwargs wrapped in a ``TransformAugmentationValueKwargs`` model.
    """
    jsonable = {key: _to_jsonable_value(value) for key, value in kwargs.items()}
    return TransformAugmentationValueKwargs.from_dict(jsonable)

to_insert_transform_augmentation(transform: AllAugmentationTransform) -> TransformAugmentationValue #

Convert a torchvision transform (one of TransformAugmentationTypes) to its insert value.

Source code in src/xpdeep/dataset/feature/augmentation/utils_serialization.py
def to_insert_transform_augmentation(  # noqa: PLR0912, PLR0911, C901
    transform: AllAugmentationTransform,
) -> TransformAugmentationValue:
    """Convert a torchvision transform TransformAugmentationTypes to insert value."""

    def make_augmentation_value(type_name: TransformType, kwargs: dict[str, object]) -> TransformAugmentationValue:
        """Convert all kwargs using the helper before creating the database model."""
        return TransformAugmentationValue(type_=type_name, kwargs=make_jsonable(kwargs))

    match transform:
        # ------------------------------------------------------------------
        # Basic conversion
        # ------------------------------------------------------------------
        case ToTensor():
            return make_augmentation_value(TransformType.TOTENSOR, {})

        case PILToTensor():
            return make_augmentation_value(TransformType.PILTOTENSOR, {})

        case ConvertImageDtype():
            return make_augmentation_value(
                TransformType.CONVERTIMAGEDTYPE,
                {"dtype": transform.dtype},
            )

        case ToPILImage():
            return make_augmentation_value(
                TransformType.TOPILIMAGE,
                {"mode": transform.mode},
            )

        # ------------------------------------------------------------------
        # Normalize
        # ------------------------------------------------------------------
        case Normalize():
            return make_augmentation_value(
                TransformType.NORMALIZE,
                {
                    "mean": transform.mean,
                    "std": transform.std,
                    "inplace": transform.inplace,
                },
            )

        # ------------------------------------------------------------------
        # Geometric
        # ------------------------------------------------------------------
        case Resize():
            return make_augmentation_value(
                TransformType.RESIZE,
                {
                    "size": transform.size,
                    "interpolation": transform.interpolation,
                    "max_size": transform.max_size,
                    "antialias": transform.antialias,
                },
            )

        case CenterCrop():
            return make_augmentation_value(
                TransformType.CENTERCROP,
                {"size": transform.size},
            )

        case Pad():
            return make_augmentation_value(
                TransformType.PAD,
                {
                    "padding": transform.padding,
                    "fill": transform.fill,
                    "padding_mode": transform.padding_mode,
                },
            )

        # ------------------------------------------------------------------
        # Random crop / flip etc.
        # ------------------------------------------------------------------
        case RandomCrop():
            return make_augmentation_value(
                TransformType.RANDOMCROP,
                {
                    "size": transform.size,
                    "padding": transform.padding,
                    "pad_if_needed": transform.pad_if_needed,
                    "fill": transform.fill,
                    "padding_mode": transform.padding_mode,
                },
            )

        case RandomHorizontalFlip():
            return make_augmentation_value(
                TransformType.RANDOMHORIZONTALFLIP,
                {"p": transform.p},
            )

        case RandomVerticalFlip():
            return make_augmentation_value(
                TransformType.RANDOMVERTICALFLIP,
                {"p": transform.p},
            )

        case RandomResizedCrop():
            return make_augmentation_value(
                TransformType.RANDOMRESIZEDCROP,
                {
                    "size": transform.size,
                    "scale": transform.scale,
                    "ratio": transform.ratio,
                    "interpolation": transform.interpolation,
                    "antialias": transform.antialias,
                },
            )

        case FiveCrop():
            return make_augmentation_value(
                TransformType.FIVECROP,
                {"size": transform.size},
            )

        case TenCrop():
            return make_augmentation_value(
                TransformType.TENCROP,
                {
                    "size": transform.size,
                    "vertical_flip": transform.vertical_flip,
                },
            )

        # ------------------------------------------------------------------
        # Color Jitter
        # ------------------------------------------------------------------
        case ColorJitter():
            return make_augmentation_value(
                TransformType.COLORJITTER,
                {
                    "brightness": transform.brightness,
                    "contrast": transform.contrast,
                    "saturation": transform.saturation,
                    "hue": transform.hue,
                },
            )

        # ------------------------------------------------------------------
        # Rotation / affine / perspective
        # ------------------------------------------------------------------
        case RandomRotation():
            return make_augmentation_value(
                TransformType.RANDOMROTATION,
                {
                    "degrees": transform.degrees,
                    "interpolation": transform.interpolation,
                    "expand": transform.expand,
                    "center": transform.center,
                    "fill": transform.fill,
                },
            )

        case RandomAffine():
            return make_augmentation_value(
                TransformType.RANDOMAFFINE,
                {
                    "degrees": transform.degrees,
                    "translate": transform.translate,
                    "scale": transform.scale,
                    "shear": transform.shear,
                    "interpolation": transform.interpolation,
                    "fill": transform.fill,
                    "center": transform.center,
                },
            )

        case RandomPerspective():
            return make_augmentation_value(
                TransformType.RANDOMPERSPECTIVE,
                {
                    "distortion_scale": transform.distortion_scale,
                    "p": transform.p,
                    "interpolation": transform.interpolation,
                    "fill": transform.fill,
                },
            )

        # ------------------------------------------------------------------
        # Grayscale
        # ------------------------------------------------------------------
        case Grayscale():
            return make_augmentation_value(
                TransformType.GRAYSCALE,
                {"num_output_channels": transform.num_output_channels},
            )

        case RandomGrayscale():
            return make_augmentation_value(
                TransformType.RANDOMGRAYSCALE,
                {"p": transform.p},
            )

        # ------------------------------------------------------------------
        # Random erase, blur
        # ------------------------------------------------------------------
        case RandomErasing():
            return make_augmentation_value(
                TransformType.RANDOMERASING,
                {
                    "p": transform.p,
                    "scale": transform.scale,
                    "ratio": transform.ratio,
                    "value": transform.value,
                    "inplace": transform.inplace,
                },
            )

        case GaussianBlur():
            return make_augmentation_value(
                TransformType.GAUSSIANBLUR,
                {
                    "kernel_size": transform.kernel_size,
                    "sigma": transform.sigma,
                },
            )

        # ------------------------------------------------------------------
        # Pixel-level
        # ------------------------------------------------------------------
        case RandomInvert():
            return make_augmentation_value(
                TransformType.RANDOMINVERT,
                {"p": transform.p},
            )

        case RandomPosterize():
            return make_augmentation_value(
                TransformType.RANDOMPOSTERIZE,
                {"bits": transform.bits, "p": transform.p},
            )

        case RandomSolarize():
            return make_augmentation_value(
                TransformType.RANDOMSOLARIZE,
                {"threshold": transform.threshold, "p": transform.p},
            )

        case RandomAdjustSharpness():
            return make_augmentation_value(
                TransformType.RANDOMADJUSTSHARPNESS,
                {"sharpness_factor": transform.sharpness_factor, "p": transform.p},
            )

        case RandomAutocontrast():
            return make_augmentation_value(
                TransformType.RANDOMAUTOCONTRAST,
                {"p": transform.p},
            )

        case RandomEqualize():
            return make_augmentation_value(
                TransformType.RANDOMEQUALIZE,
                {"p": transform.p},
            )

        # ------------------------------------------------------------------
        # Elastic
        # ------------------------------------------------------------------
        case ElasticTransform():
            return make_augmentation_value(
                TransformType.ELASTICTRANSFORM,
                {
                    "alpha": transform.alpha,
                    "sigma": transform.sigma,
                    "interpolation": transform.interpolation,
                    "fill": transform.fill,
                },
            )

        # ------------------------------------------------------------------
        # InterpolationMode enum
        # ------------------------------------------------------------------
        case InterpolationMode():
            return make_augmentation_value(
                TransformType.INTERPOLATIONMODE,
                {"value": transform.value},
            )

        case _:
            msg = f"Cannot convert the augmentation {transform} to database model."
            raise InvalidTransformError(msg)

to_insert_augmentation(data: AllAugmentationTransform) -> ComposeAugmentationValueInput | LinearTransformationAugmentationValue | RandomApplyAugmentationValueInput | RandomChoiceAugmentationValueInput | RandomOrderAugmentationValueInput | TransformAugmentationValue | None #

Convert an augmentation to its insert model value.

Source code in src/xpdeep/dataset/feature/augmentation/utils_serialization.py
def to_insert_augmentation(
    data: AllAugmentationTransform,
) -> (
    ComposeAugmentationValueInput
    | LinearTransformationAugmentationValue
    | RandomApplyAugmentationValueInput
    | RandomChoiceAugmentationValueInput
    | RandomOrderAugmentationValueInput
    | TransformAugmentationValue
    | None
):
    """Convert an augmentation to its insert model value."""
    match data:
        case Compose():
            augmentation_value = ComposeAugmentationValueInput(
                type_="Compose",
                transforms=[
                    to_insert_transform_augmentation(transform=single_transform) for single_transform in data.transforms
                ],
            )
        case RandomApply():
            augmentation_value = RandomApplyAugmentationValueInput(  # type: ignore[assignment]
                type_="RandomApply",
                p=data.p,
                transforms=[
                    to_insert_transform_augmentation(transform=single_transform) for single_transform in data.transforms
                ],
            )
        case RandomChoice():
            augmentation_value = RandomChoiceAugmentationValueInput(  # type: ignore[assignment]
                type_="RandomChoice",
                p=data.p,
                transforms=[
                    to_insert_transform_augmentation(transform=single_transform) for single_transform in data.transforms
                ],
            )
        case RandomOrder():
            augmentation_value = RandomOrderAugmentationValueInput(  # type: ignore[assignment]
                type_="RandomOrder",
                transforms=[
                    to_insert_transform_augmentation(transform=single_transform) for single_transform in data.transforms
                ],
            )

        case LinearTransformation():
            augmentation_value = LinearTransformationAugmentationValue(  # type: ignore[assignment]
                type_="LinearTransformation",
                transformation_matrix=data.transformation_matrix.cpu().numpy(),
                mean_vector=data.mean_vector.cpu().numpy(),
            )
        case t if isinstance(t, TransformAugmentationTypes):
            augmentation_value = to_insert_transform_augmentation(transform=data)  # type: ignore[assignment]
        case _:
            msg = f"Cannot convert the augmentation {data} to insert model."
            raise InvalidTransformError(msg)
    return augmentation_value