Skip to content

utils

Conversion helpers that turn xpdeep-types serialized-module objects into xpdeep database / API-client types.

Functions:

Name Description
convert_payload

Convert a ModuleState or TorchArtifact payload from xpdeep_types to its database model.

convert_serialized_module

Convert from xpdeep-types to xpdeep-api-client type.

convert_payload(payload: ModuleState | TorchArtifact) -> ModuleStateDatabase | TorchArtifactDatabase #

convert_payload(
    payload: TorchArtifact,
) -> TorchArtifactDatabase
convert_payload(
    payload: ModuleState,
) -> ModuleStateDatabase

Convert a ModuleState or TorchArtifact payload from xpdeep_types to its database model.

Source code in src/xpdeep/model/utils.py
def convert_payload(payload: ModuleState | TorchArtifact) -> ModuleStateDatabase | TorchArtifactDatabase:
    """Convert a ModuleState or TorchArtifact payload from xpdeep_types to its database model.

    Every raw ``bytes`` field is base64-encoded and returned as a UTF-8 ``str``.

    Args:
        payload: The payload to convert.

    Returns:
        A ``TorchArtifactDatabase`` when ``payload`` is a ``TorchArtifact``, or a
        ``ModuleStateDatabase`` when ``payload`` is a ``ModuleState``.

    Raises:
        ModuleSerializationError: If ``payload`` is neither a ``TorchArtifact`` nor a ``ModuleState``.
    """

    def _to_text(raw: bytes) -> str:
        # Base64-encode raw bytes, then decode to str ("bytes to str").
        return b64encode(raw).decode("utf-8")

    if isinstance(payload, TorchArtifact):
        return TorchArtifactDatabase(
            exported_program=_to_text(payload.exported_program),
            state_dict=_to_text(payload.state_dict),
            constants=_to_text(payload.constants),
            # example_inputs is optional: keep None as None instead of encoding it.
            example_inputs=_to_text(payload.example_inputs) if payload.example_inputs is not None else None,
        )
    if isinstance(payload, ModuleState):
        return ModuleStateDatabase(
            module=payload.module_,
            class_=payload.class_,
            state_dict=_to_text(payload.state_dict),
            init_params=ModuleStateInitParams.from_dict(payload.init_params),
            use_safetensors=payload.use_safetensors,
        )
    # Defensive fallback for values that escape the static type (mypy flags it as unreachable).
    msg = f"Cannot convert {type(payload)} to xpdeep-database type."  # type: ignore[unreachable]
    raise ModuleSerializationError(msg)

convert_serialized_module(module: SerializedModule) -> SerializedModuleInput #

Convert from xpdeep-types to xpdeep-api-client type.

Source code in src/xpdeep/model/utils.py
def convert_serialized_module(module: SerializedModule) -> SerializedModuleInput:
    """Convert a SerializedModule from xpdeep-types to the xpdeep-api-client input type.

    Args:
        module: The serialized module to convert.

    Returns:
        A ``SerializedModuleInput`` carrying the module's version and package metadata,
        with its payload converted by ``convert_payload``.

    Raises:
        ModuleSerializationError: If the module's payload is a ``PickledModule``, which this
            function does not support.
    """
    payload = module.payload
    if not isinstance(payload, PickledModule):
        return SerializedModuleInput(
            module_version=module.module_version,
            package=module.package,
            package_version=module.package_version,
            payload=convert_payload(payload),
        )
    # Pickled payloads are rejected rather than converted.
    msg = f"Cannot convert {type(payload)} to xpdeep-database type."
    raise ModuleSerializationError(msg)