Skip to content

augmentation

Feature augmentation utilities.

Classes:

Name Description
FeatureAugmentation

Feature augmentation class, mainly for images.

Functions:

Name Description
from_select_transform_augmentation

Reverse of to_insert_transform_augmentation using **kwargs.

trusted_augmentation_from_model

Convert a serialized transform to a torchvision transform.

augmentation_from_model

Reverse of to_model().

FeatureAugmentation #

Feature augmentation class, mainly for images.

For images, the corresponding data uses the channel-last format, i.e. batch_size x H x W x num_channels. You may need to use Compose([Permute([0, 3, 1, 2]), YourTransformation(), Permute([0, 2, 3, 1])]) if your augmentation requires the channel-first format. You don't need to convert the data to a torch tensor first, as this is handled automatically by xpdeep.

Parameters:

Name Type Description Default

augment_raw #

AllAugmentationTransform | None

A feature augmentation function used to augment raw data. This is done before data preprocessing.

None

augment_preprocessed #

AllAugmentationTransform | None

A feature augmentation function used to augment data after it has been preprocessed.

None

Methods:

Name Description
__attrs_post_init__

Post init method, triggered after calling the `__init__` method to check the instance's validity.

to_model

Convert to AugmentationInsert.

stable_hash

Return the hash.

Attributes:

Name Type Description
augment_raw AllAugmentationTransform | None
augment_preprocessed AllAugmentationTransform | None

augment_raw: AllAugmentationTransform | None = field(default=None) #

augment_preprocessed: AllAugmentationTransform | None = field(default=None) #

__attrs_post_init__() -> None #

Post init method, triggered after calling the `__init__` method to check the instance's validity.

Source code in src/xpdeep/dataset/feature/augmentation/augmentation.py
def __attrs_post_init__(self) -> None:
    """Check the instance's validity once the generated ``__init__`` has run.

    Raises
    ------
    Warning
        If both ``augment_raw`` and ``augment_preprocessed`` are ``None``.
    """
    # At least one augmentation is configured: nothing to report.
    if self.augment_raw is not None or self.augment_preprocessed is not None:
        return

    msg = "`augment_raw` and `augment_preprocessed` are both None, no augmentation will be applied."
    # NOTE: ``Warning`` is raised as an exception here, it is not emitted through ``warnings.warn``.
    raise Warning(msg)

to_model() -> AugmentationInsert #

Convert to AugmentationInsert.

Source code in src/xpdeep/dataset/feature/augmentation/augmentation.py
def to_model(self) -> AugmentationInsert:
    """Build the ``AugmentationInsert`` counterpart of this instance.

    Each augmentation that is set is passed through ``to_insert_augmentation``; an augmentation
    left as ``None`` is forwarded as ``None``.
    """
    serialized_raw, serialized_preprocessed = (
        None if transform is None else to_insert_augmentation(transform)
        for transform in (self.augment_raw, self.augment_preprocessed)
    )
    return AugmentationInsert(raw_value=serialized_raw, preprocessed_value=serialized_preprocessed)

stable_hash() -> str #

Return the hash.

Source code in src/xpdeep/dataset/feature/augmentation/augmentation.py
def stable_hash(self) -> str:
    """Return the SHA-256 hex digest of the string form of ``self.to_model().to_dict()``."""
    payload = str(self.to_model().to_dict())
    # ``hexdigest()`` already returns a ``str``, so no further conversion is needed.
    return hashlib.sha256(payload.encode()).hexdigest()

from_select_transform_augmentation(value: TransformAugmentationValue) -> AllAugmentationTransform #

Reverse of to_insert_transform_augmentation using **kwargs.

Source code in src/xpdeep/dataset/feature/augmentation/augmentation.py
def from_select_transform_augmentation(value: TransformAugmentationValue) -> AllAugmentationTransform:
    """Reverse of to_insert_transform_augmentation using **kwargs.

    The transform class is looked up from ``value.type_`` in a fixed table of supported names and is
    instantiated with ``value.kwargs.to_dict()`` as keyword arguments.

    Parameters
    ----------
    value : TransformAugmentationValue
        Serialized transform; ``type_`` holds the class name and ``kwargs`` its constructor arguments.

    Returns
    -------
    AllAugmentationTransform
        The rebuilt transform instance, or an ``InterpolationMode`` member when ``type_`` is
        ``"InterpolationMode"``.

    Raises
    ------
    ApiError
        If ``value.type_`` is not one of the supported names.
    """
    transform_type = value.type_
    kwargs = value.kwargs.to_dict() or {}

    cls_map = {
        # Basic transforms
        "ToTensor": ToTensor,
        "PILToTensor": PILToTensor,
        "ConvertImageDtype": ConvertImageDtype,
        "ToPILImage": ToPILImage,
        # Normalize
        "Normalize": Normalize,
        # Geometric
        "Resize": Resize,
        "CenterCrop": CenterCrop,
        "Pad": Pad,
        # Random crop / flips / resize-crop
        "RandomCrop": RandomCrop,
        "RandomHorizontalFlip": RandomHorizontalFlip,
        "RandomVerticalFlip": RandomVerticalFlip,
        "RandomResizedCrop": RandomResizedCrop,
        "FiveCrop": FiveCrop,
        "TenCrop": TenCrop,
        # Color
        "ColorJitter": ColorJitter,
        # Rotation / Affine / Perspective
        "RandomRotation": RandomRotation,
        "RandomAffine": RandomAffine,
        "RandomPerspective": RandomPerspective,
        # Grayscale
        "Grayscale": Grayscale,
        "RandomGrayscale": RandomGrayscale,
        # Erase / blur
        "RandomErasing": RandomErasing,
        "GaussianBlur": GaussianBlur,
        # Pixel-level transforms
        "RandomInvert": RandomInvert,
        "RandomPosterize": RandomPosterize,
        "RandomSolarize": RandomSolarize,
        "RandomAdjustSharpness": RandomAdjustSharpness,
        "RandomAutocontrast": RandomAutocontrast,
        "RandomEqualize": RandomEqualize,
        # Elastic
        "ElasticTransform": ElasticTransform,
        # Interpolation enum
        "InterpolationMode": InterpolationMode,
    }

    if transform_type not in cls_map:
        msg = f"Unknown transform type: {transform_type}"
        raise ApiError(msg)

    cls = cls_map[transform_type]

    # Special case: InterpolationMode is an enum; must call InterpolationMode(value)
    if cls is InterpolationMode:
        return InterpolationMode(kwargs["value"])

    # Specific cases json deserialization (if a list we keep a list it should work, no need to convert back to tuple)
    interpolation_aware = {ElasticTransform, Resize, RandomResizedCrop, RandomAffine, RandomPerspective, RandomRotation}

    # The serialized interpolation is resolved by enum member *name*. It is only converted when the
    # payload actually carries it: a payload without it previously failed with a bare KeyError, whereas
    # the transform's own default interpolation is a valid outcome.
    if cls in interpolation_aware and "interpolation" in kwargs:
        kwargs["interpolation"] = InterpolationMode[kwargs["interpolation"]]

    # For all other transforms, just unpack kwargs directly
    return cls(**kwargs)

trusted_augmentation_from_model(value: ComposeAugmentationValueInput | LinearTransformationAugmentationValue | RandomApplyAugmentationValueInput | RandomChoiceAugmentationValueInput | RandomOrderAugmentationValueInput | TransformAugmentationValue | None) -> AllAugmentationTransform | None #

Convert a serialized transform to a torchvision transform.

Source code in src/xpdeep/dataset/feature/augmentation/augmentation.py
def trusted_augmentation_from_model(  # noqa:PLR0911
    value: ComposeAugmentationValueInput
    | LinearTransformationAugmentationValue
    | RandomApplyAugmentationValueInput
    | RandomChoiceAugmentationValueInput
    | RandomOrderAugmentationValueInput
    | TransformAugmentationValue
    | None,
) -> AllAugmentationTransform | None:
    """Rebuild the torchvision transform described by a serialized augmentation value.

    Container values (compose, random apply, random choice, random order) are rebuilt recursively
    from their ``transforms``. ``None`` is returned unchanged.
    """
    # Compose: rebuild every child, in order.
    if isinstance(value, ComposeAugmentationValueInput):
        children = [trusted_augmentation_from_model(child) for child in value.transforms]
        return Compose(children)

    # RandomApply: children plus the stored probability.
    if isinstance(value, RandomApplyAugmentationValueInput):
        children = [trusted_augmentation_from_model(child) for child in value.transforms]
        return RandomApply(transforms=children, p=value.p)

    # RandomChoice: children plus the stored ``p``.
    if isinstance(value, RandomChoiceAugmentationValueInput):
        children = [trusted_augmentation_from_model(child) for child in value.transforms]
        return RandomChoice(transforms=children, p=value.p)

    # RandomOrder: children only.
    if isinstance(value, RandomOrderAugmentationValueInput):
        children = [trusted_augmentation_from_model(child) for child in value.transforms]
        return RandomOrder(transforms=children)

    # LinearTransformation: both stored arrays are handed to ``torch.from_numpy``.
    if isinstance(value, LinearTransformationAugmentationValue):
        transformation_matrix = torch.from_numpy(value.transformation_matrix)
        mean_vector = torch.from_numpy(value.mean_vector)
        return LinearTransformation(transformation_matrix, mean_vector)

    # Simple transform: delegate to the kwargs-based factory.
    if isinstance(value, TransformAugmentationValue):
        return from_select_transform_augmentation(value)

    if value is None:
        return None

    # Every member of the declared union is handled above.
    assert_never(value)

augmentation_from_model(augmentation_select: AugmentationSelectInput | AugmentationInsert) -> FeatureAugmentation #

Reverse of to_model().

Source code in src/xpdeep/dataset/feature/augmentation/augmentation.py
def augmentation_from_model(augmentation_select: AugmentationSelectInput | AugmentationInsert) -> FeatureAugmentation:
    """Reverse of to_model().

    A field holding an ``Unset`` instance becomes ``None``; any other value is rebuilt with
    ``trusted_augmentation_from_model``.
    """
    raw_value = augmentation_select.raw_value
    if isinstance(raw_value, Unset):
        augment_raw = None
    else:
        augment_raw = trusted_augmentation_from_model(raw_value)

    preprocessed_value = augmentation_select.preprocessed_value
    if isinstance(preprocessed_value, Unset):
        augment_preprocessed = None
    else:
        augment_preprocessed = trusted_augmentation_from_model(preprocessed_value)

    return FeatureAugmentation(augment_raw=augment_raw, augment_preprocessed=augment_preprocessed)