torch_model
Define the TorchModel, an interface to create an XpdeepModel from a torch nn.Module.
Classes:
Name | Description |
---|---|
TorchModel | Class to represent a torch model using torch.fx serialization. |
TorchModel(exported_program: TorchSerializedArtifact, module: nn.Module)
Class to represent a torch model using torch.fx serialization.
Methods:
Name | Description |
---|---|
safe_serialize_torch_artifact | Serialize a torch artifact to bytes with safetensors. |
from_torch_module | Create a TorchModel from a torch module; the input size must not include the batch dimension. |
from_torch_loss | Create a TorchModel from a torch loss, given the output and target sizes without the batch dimension. |
from_exported | Create a TorchModel from an exported program and its original module. |
Attributes:
Name | Type | Description |
---|---|---|
original_module | | |
Source code in src/xpdeep/model/torch_model.py
original_module = module
safe_serialize_torch_artifact(artifact: dict[str, torch.Tensor]) -> bytes
Serialize a torch artifact to bytes with safetensors.
The safetensors format does not rely on pickle. It stores a JSON header followed by raw tensor bytes, and is therefore safe to deserialize.
See: https://huggingface.co/docs/safetensors/index
Parameters:
Name | Type | Description | Default |
---|---|---|---|
artifact | dict[str, Tensor] | The torch artifact to serialize. | required |
Returns:
Type | Description |
---|---|
bytes | The serialized artifact. |
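To illustrate why a JSON-based layout is safe to read, here is a minimal standard-library sketch of the safetensors byte layout: an 8-byte little-endian header length, a JSON header, then the raw tensor bytes. It does not call xpdeep or the safetensors library. `pack_safetensors_like` and `read_header` are hypothetical helpers written for this example, and the real library performs additional validation.

```python
import json
import struct


def pack_safetensors_like(tensors: dict[str, tuple[str, list[int], bytes]]) -> bytes:
    """Pack raw buffers as: 8-byte header size, JSON header, concatenated data.

    Hypothetical helper; each value is (dtype name, shape, raw bytes).
    """
    header = {}
    chunks = []
    offset = 0
    for name, (dtype, shape, raw) in tensors.items():
        # Each entry records where its bytes live inside the data section.
        header[name] = {
            "dtype": dtype,
            "shape": shape,
            "data_offsets": [offset, offset + len(raw)],
        }
        offset += len(raw)
        chunks.append(raw)
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    return struct.pack("<Q", len(header_bytes)) + header_bytes + b"".join(chunks)


def read_header(blob: bytes) -> dict:
    """Parse only the JSON header. No code from the blob is executed."""
    (header_len,) = struct.unpack("<Q", blob[:8])
    return json.loads(blob[8 : 8 + header_len])


# Four float32 values standing in for a 2x2 weight matrix.
weights = struct.pack("<4f", 1.0, 2.0, 3.0, 4.0)
blob = pack_safetensors_like({"linear.weight": ("F32", [2, 2], weights)})
header = read_header(blob)
print(header["linear.weight"]["shape"])  # [2, 2]
```

Reading the blob only involves `struct.unpack` and `json.loads`, whereas unpickling can run arbitrary code chosen by whoever produced the file.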
from_torch_module(module: torch.nn.Module, input_size: int | tuple[int, ...]) -> TorchModel
Create a TorchModel from a torch module. The input size must not include the batch dimension.
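The `input_size` argument accepts either an `int` or a tuple and describes a single sample. The sketch below shows one way such a per-sample size could be turned into a full tensor shape. `with_batch_dim` is a hypothetical helper, and the assumption that a batch dimension is prepended internally is ours; the page does not show how TorchModel actually uses the value.

```python
from typing import Union


def with_batch_dim(
    input_size: Union[int, tuple[int, ...]], batch_size: int = 1
) -> tuple[int, ...]:
    """Return a full shape for a per-sample size (hypothetical helper)."""
    # Normalize a bare int such as 10 into the one-element tuple (10,).
    per_sample = (input_size,) if isinstance(input_size, int) else tuple(input_size)
    # Prepend the batch dimension that the caller is expected to leave out.
    return (batch_size, *per_sample)


print(with_batch_dim(10))              # (1, 10)
print(with_batch_dim((3, 224, 224)))   # (1, 3, 224, 224)
```

Under this assumption, passing `(1, 3, 224, 224)` as `input_size` would describe a five-dimensional input once the batch dimension is added, which is why the batch dimension should be omitted.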
from_torch_loss(module: _Loss, output_size: tuple[int, ...], target_size: tuple[int, ...]) -> TorchModel
Create a TorchModel from a torch loss, given the output and target sizes without the batch dimension.
from_exported(exported_program: ExportedProgram, module: nn.Module) -> TorchModel
Create a TorchModel from an exported program and its original module.