
loss

Loss wrapper.

Classes:

| Name | Description |
| --- | --- |
| XpdeepLoss | Xpdeep loss wrapper. |

XpdeepLoss

Xpdeep loss wrapper.

loss: AbstractLossCriterion. The wrapped loss is one of two kinds. It is either a custom loss (an ApiLoss, exported as a torchArtifact) or a trusted Xpdeep custom loss (exported as a ModuleState).
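The description above says one wrapper can hold two different kinds of loss. Here is a minimal sketch of how that can work. It assumes the two kinds share a common abstract base, as the `AbstractLossCriterion` annotation suggests, and differ only in their export format. Every class name below is a hypothetical stand-in, not part of the Xpdeep API.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass


class LossCriterionSketch(ABC):
    """Hypothetical stand-in for AbstractLossCriterion."""

    @abstractmethod
    def export_format(self) -> str:
        """Name of the format this criterion is exported as."""


class CustomLossSketch(LossCriterionSketch):
    """Hypothetical stand-in for a user-defined loss (an ApiLoss in the text above)."""

    def export_format(self) -> str:
        return "torchArtifact"


class TrustedLossSketch(LossCriterionSketch):
    """Hypothetical stand-in for a trusted Xpdeep custom loss."""

    def export_format(self) -> str:
        return "ModuleState"


@dataclass
class LossWrapperSketch:
    """Holds exactly one criterion and relies only on the abstract interface."""

    loss: LossCriterionSketch

    def describe(self) -> str:
        # No isinstance checks: each concrete criterion reports its own export format.
        return f"{type(self.loss).__name__} exported as {self.loss.export_format()}"


print(LossWrapperSketch(CustomLossSketch()).describe())   # CustomLossSketch exported as torchArtifact
print(LossWrapperSketch(TrustedLossSketch()).describe())  # TrustedLossSketch exported as ModuleState
```

`LossWrapperSketch` depends only on the abstract interface. Code that consumes the wrapper therefore never needs to branch on which kind of loss it holds.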

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| loss | AbstractLossCriterion | The loss criterion to wrap. | required |

Methods:

| Name | Description |
| --- | --- |
| from_torch | Build a loss object, inferring the required tensor sizes from the model and the fitted schema. |
| to_model | Convert the loss object to a database model input. |

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| loss | AbstractLossCriterion | The wrapped loss criterion. |

loss: AbstractLossCriterion

from_torch(fitted_schema: FittedSchema, model: XpdeepModel, loss: nn.Module) -> Self

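Build a loss object from a PyTorch loss module, inferring the required tensor sizes from the model and fitted_schema. The method builds two random dummy tensors, each with a batch of 5. Their trailing dimensions come from the fitted schema's target size and from the model's output size. Both tensors are passed with loss to ApiLoss.from_torch_loss.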

Source code in src/xpdeep/trainer/loss.py
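```python
@classmethod
def from_torch(cls, fitted_schema: FittedSchema, model: XpdeepModel, loss: nn.Module) -> Self:
    """Build a loss object - infer required size from the model and fitted_schema."""
    output_size = model.get_output_size(fitted_schema)

    # Random dummy tensors with a batch of 5: only the trailing (non-batch) dimensions
    # are taken from the fitted schema's target size and from the model's output size.
    inputs = torch.randn(*(5, *fitted_schema.target_size[1:]))
    targets = torch.randn(*(5, *output_size[1:]))
    # Build an ApiLoss from the torch loss, passing the dummy tensors as example predictions and targets.
    api_loss = ApiLoss.from_torch_loss(loss=loss, predictions=inputs, targets=targets)
    return cls(loss=api_loss)
```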
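`from_torch` never sees real data. It builds a dummy batch of 5 samples and passes it to `ApiLoss.from_torch_loss`. The trailing dimensions come from the fitted schema's target size and from the model's output size. The pure-Python sketch below mimics that probe. It assumes the purpose is to check that the loss accepts tensors of those shapes. `randn` and `mse` are hypothetical stand-ins for `torch.randn` and a torch criterion such as `nn.MSELoss`, and tensors are flattened into plain lists.

```python
import random
from math import prod
from typing import Callable, Sequence

DUMMY_BATCH = 5  # the source hard-codes a dummy batch of 5 samples


def dummy_shape(size: Sequence[int]) -> tuple:
    """Mirror `(5, *size[1:])`: keep every dimension except the first, which becomes the dummy batch."""
    return (DUMMY_BATCH, *size[1:])


def randn(shape: tuple, rng: random.Random) -> list:
    """Hypothetical stand-in for torch.randn(*shape): a flat list of standard-normal samples."""
    return [rng.gauss(0.0, 1.0) for _ in range(prod(shape))]


def mse(predictions: list, targets: list) -> float:
    """Hypothetical stand-in for a torch criterion such as nn.MSELoss (mean reduction)."""
    if len(predictions) != len(targets):
        raise ValueError(f"size mismatch: {len(predictions)} vs {len(targets)} elements")
    return sum((p - t) ** 2 for p, t in zip(predictions, targets)) / len(predictions)


def probe_loss(target_size, output_size, loss_fn: Callable[[list, list], float], seed: int = 0) -> float:
    """Evaluate `loss_fn` once on random dummy data, using the same shape rules as from_torch."""
    rng = random.Random(seed)
    inputs = randn(dummy_shape(target_size), rng)   # shaped like fitted_schema.target_size in the source
    targets = randn(dummy_shape(output_size), rng)  # shaped like model.get_output_size(...) in the source
    return loss_fn(inputs, targets)


print(dummy_shape((32, 10)))                       # (5, 10)
print(probe_loss((32, 10), (32, 10), mse) >= 0.0)  # True
```

With matching sizes, the probe returns a non-negative number. With mismatched sizes, for example `(32, 10)` against `(32, 3)`, `mse` raises `ValueError`. That is the kind of incompatibility a dummy evaluation can surface before any training starts.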

to_model() -> SerializedModuleInput

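Convert the loss object to a database model input. The method serializes the wrapped loss with to_pydantic() and passes the result to convert_serialized_module.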

Source code in src/xpdeep/trainer/loss.py
```python
def to_model(self) -> SerializedModuleInput:
    """Convert the loss object to a database model input."""
    return convert_serialized_module(self.loss.to_pydantic())
```
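`to_model` does no serialization of its own. It asks the wrapped criterion for a pydantic description (`to_pydantic()`) and hands that to `convert_serialized_module`. Below is a minimal sketch of that delegation, assuming a JSON payload as the storage format. `PydanticLikeModule`, `DatabaseInputSketch`, `convert_sketch`, `MSELossSketch` and `SerializableLossWrapperSketch` are all hypothetical stand-ins. The real `SerializedModuleInput` may be structured quite differently.

```python
import json
from dataclasses import dataclass, field


@dataclass
class PydanticLikeModule:
    """Hypothetical stand-in for the object returned by `loss.to_pydantic()`."""

    class_name: str
    kwargs: dict = field(default_factory=dict)


@dataclass
class DatabaseInputSketch:
    """Hypothetical stand-in for SerializedModuleInput: a JSON payload ready to store."""

    payload: str


def convert_sketch(module: PydanticLikeModule) -> DatabaseInputSketch:
    """Hypothetical stand-in for convert_serialized_module."""
    data = {"class_name": module.class_name, "kwargs": module.kwargs}
    return DatabaseInputSketch(json.dumps(data, sort_keys=True))


class MSELossSketch:
    """Hypothetical criterion that knows how to describe itself."""

    def __init__(self, reduction: str = "mean") -> None:
        self.reduction = reduction

    def to_pydantic(self) -> PydanticLikeModule:
        return PydanticLikeModule("MSELossSketch", {"reduction": self.reduction})


@dataclass
class SerializableLossWrapperSketch:
    loss: MSELossSketch

    def to_model(self) -> DatabaseInputSketch:
        # Same two steps as XpdeepLoss.to_model: describe the criterion, then convert that description.
        return convert_sketch(self.loss.to_pydantic())


print(SerializableLossWrapperSketch(MSELossSketch("sum")).to_model().payload)
# {"class_name": "MSELossSketch", "kwargs": {"reduction": "sum"}}
```

Serialization lives in the criterion, so the wrapper's `to_model` stays a one-liner whichever concrete loss it holds.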