Skip to content

torch_model

Define the TorchModel, an interface to create an XpdeepModel from a torch nn module.

Classes:

Name Description
TorchModel

Class to represent a torch model using torch.fx serialization.

TorchModel(exported_program: TorchSerializedArtifact, module: nn.Module) #

Class to represent a torch model using torch.fx serialization.

Methods:

Name Description
safe_serialize_torch_artifact

Serialize a torch artifact to bytes with safetensors.

from_torch_module

Create a TorchModel from a torch module; the input size should not contain the batch dimension.

from_torch_loss

Create a TorchModel from a torch loss, given the output and target sizes without the batch dimension.

from_exported

Create a TorchModel from a torch ExportedProgram.

Attributes:

Name Type Description
original_module
Source code in src/xpdeep/model/torch_model.py
def __init__(self, exported_program: TorchSerializedArtifact, module: nn.Module) -> None:
    """
    Build the model from a serialized torch artifact and keep the source module.

    Parameters
    ----------
    exported_program : TorchSerializedArtifact
        The serialized artifact. Its ``exported_program``, ``state_dict``, ``constants`` and
        ``example_inputs`` fields are each base64-encoded and decoded to UTF-8 text before being
        passed to the parent constructor.
    module : nn.Module
        The torch module the artifact was produced from, stored as ``original_module``.
    """
    self.original_module = module

    def _as_base64_text(raw: bytes) -> str:
        # The request body carries text, so the raw payload is base64-encoded first.
        return b64encode(raw).decode("utf-8")

    request_body = SerializedArtifactRequestBody(
        exported_program=_as_base64_text(exported_program.exported_program),
        state_dict=_as_base64_text(exported_program.state_dict),
        constants=_as_base64_text(exported_program.constants),
        example_inputs=_as_base64_text(exported_program.example_inputs),
    )
    super().__init__(exported_program=request_body)

original_module = module #

safe_serialize_torch_artifact(artifact: dict[str, torch.Tensor]) -> bytes #

Serialize a torch artifact to bytes with safetensors.

A SafeTensor doesn't rely on pickle but JSON and is therefore safe to deserialize. See: https://huggingface.co/docs/safetensors/index

Parameters:

Name Type Description Default
artifact #
dict[str, Tensor]

The torch artifact to serialize.

required

Returns:

Type Description
bytes

The serialized artifact.

Source code in src/xpdeep/model/torch_model.py
@staticmethod
def safe_serialize_torch_artifact(artifact: dict[str, torch.Tensor]) -> bytes:
    """
    Serialize a torch artifact to bytes with safetensors.

    A SafeTensor doesn't rely on `pickle` but JSON and is therefore safe to deserialize.
    See: https://huggingface.co/docs/safetensors/index

    For every entry that is a ``torch.nn.Parameter``, an extra one-element marker tensor is stored
    under ``<key>_PARAM`` so that the parameter status can be identified from the serialized data.
    The markers are added to a copy: the ``artifact`` dictionary passed by the caller is left
    unchanged.

    Parameters
    ----------
    artifact : dict[str, torch.Tensor]
        The torch artifact to serialize.

    Returns
    -------
    bytes
        The serialized artifact.
    """
    # Shallow copy so that the marker entries do not leak into the caller's dictionary.
    payload = dict(artifact)
    for key, tensor in artifact.items():
        if isinstance(tensor, torch.nn.Parameter):
            # add a special tensor to indicate that this tensor is a parameter
            payload[key + "_PARAM"] = torch.Tensor([True])
    return save(payload)

from_torch_module(module: torch.nn.Module, input_size: int | tuple[int, ...]) -> TorchModel #

Create a TorchModel from a torch module; the input size should not contain the batch dimension.

Source code in src/xpdeep/model/torch_model.py
@staticmethod
def from_torch_module(module: torch.nn.Module, input_size: int | tuple[int, ...]) -> TorchModel:
    """
    Build a ``TorchModel`` by exporting a torch module with ``torch.export``.

    Parameters
    ----------
    module : torch.nn.Module
        The module to export.
    input_size : int | tuple[int, ...]
        Size of a single input sample, without the batch dimension. An integer is treated as a
        one-element tuple.

    Returns
    -------
    TorchModel
        The model built from the exported program and ``module``.
    """
    sample_shape = (input_size,) if isinstance(input_size, int) else input_size

    # The first dimension of the input is exported as a dynamic batch size.
    # Its lower bound is 2: a batch size of "1" cannot be used because of a bug in pytorch.
    batch_dim = Dim("batch", min=2, max=None)

    # Random example input used for tracing, with 128 samples along the batch dimension.
    example_input = torch.randn((128, *sample_shape))

    # About `strict` (see https://pytorch.org/docs/stable/export.html):
    # in strict mode the program is first traced with TorchDynamo, a bytecode analysis engine that
    # analyzes the Python code symbolically instead of executing it and builds a graph from the results.
    # This lets torch.export give stronger safety guarantees, but not all Python code is supported.
    exported = export(
        module,
        args=(example_input,),
        strict=True,
        dynamic_shapes=({0: batch_dim},),
    )
    return TorchModel.from_exported(exported, module)

from_torch_loss(module: _Loss, output_size: tuple[int, ...], target_size: tuple[int, ...]) -> TorchModel #

Create a TorchModel from a torch loss, given the output and target sizes without the batch dimension.

Source code in src/xpdeep/model/torch_model.py
@staticmethod
def from_torch_loss(module: _Loss, output_size: tuple[int, ...], target_size: tuple[int, ...]) -> TorchModel:
    """
    Build a ``TorchModel`` by exporting a torch loss with ``torch.export``.

    Parameters
    ----------
    module : _Loss
        The loss module to export. It is traced with two positional arguments, built from
        ``output_size`` and ``target_size`` respectively.
    output_size : tuple[int, ...]
        Size of a single model output, without the batch dimension.
    target_size : tuple[int, ...]
        Size of a single target, without the batch dimension.

    Returns
    -------
    TorchModel
        The model built from the exported program and ``module``.
    """
    # The first dimension of both the output and the target is exported as a dynamic batch size.
    # Its lower bound is 2: a batch size of "1" cannot be used because of a bug in pytorch.
    batch_dim = Dim("batch", min=2, max=None)
    shapes_spec = [{0: batch_dim}, {0: batch_dim}]  # wrong typing in torch, tuple accepted but not in mypy

    # Random example tensors used for tracing, with 128 samples along the batch dimension.
    example_output = torch.randn(128, *output_size)
    example_target = torch.randn(128, *target_size)

    # About `strict` (see https://pytorch.org/docs/stable/export.html):
    # in strict mode the program is first traced with TorchDynamo, a bytecode analysis engine that
    # analyzes the Python code symbolically instead of executing it and builds a graph from the results.
    # This lets torch.export give stronger safety guarantees, but not all Python code is supported.
    exported = export(
        module,
        args=(example_output, example_target),
        strict=True,
        dynamic_shapes=shapes_spec,
    )
    return TorchModel.from_exported(exported, module=module)

from_exported(exported_program: ExportedProgram, module: nn.Module) -> TorchModel #

Create a TorchModel from a torch ExportedProgram.

Source code in src/xpdeep/model/torch_model.py
@staticmethod
def from_exported(exported_program: ExportedProgram, module: nn.Module) -> TorchModel:
    """
    Build a ``TorchModel`` from an already exported torch program.

    The program is serialized with torch, then its ``state_dict`` and ``constants`` payloads are
    reloaded and re-serialized with safetensors before the ``TorchModel`` is created.

    Parameters
    ----------
    exported_program : ExportedProgram
        The exported program to serialize.
    module : nn.Module
        The torch module the program was exported from.

    Returns
    -------
    TorchModel
        The model wrapping the serialized artifact and ``module``.
    """
    # ``serialize`` is used here rather than ``ExportedProgramSerializer(None).serialize``:
    # torch.Export does not handle attrs objects whose non-optional attributes sometimes hold
    # ``None`` values, so we rely on the "slow" torch encoder instead.
    artifact = serialize(exported_program)

    # First reload both payloads in torch.
    # A torch bug emits FutureWarning during deserialization; remove the filter in torch 2.5.
    with warnings.catch_warnings():
        warnings.simplefilter(action="ignore", category=FutureWarning)
        reloaded_state_dict = deserialize_torch_artifact(artifact.state_dict)
        reloaded_constants = deserialize_torch_artifact(artifact.constants)

    # Then overwrite the torch-serialized payloads with their safetensors counterparts.
    artifact.state_dict = TorchModel.safe_serialize_torch_artifact(reloaded_state_dict)
    artifact.constants = TorchModel.safe_serialize_torch_artifact(reloaded_constants)

    return TorchModel(exported_program=artifact, module=module)