Skip to content

model_training

Model training pipeline utilities.

Functions:

Name Description
get_pipeline_result

Get training pipeline results.

get_pipeline_result(pipeline_id: str) -> TrainedModelArtifact #

Get training pipeline results.

Source code in src/xpdeep/utils/pipelines/model_training.py
@initialized_client_verification
@initialized_project_verification
@retry_on_exception((httpx.RemoteProtocolError, urllib3.exceptions.ProtocolError), max_retries=10)
def get_pipeline_result(pipeline_id: str) -> TrainedModelArtifact:
    """Follow a training pipeline's progress stream and return its result.

    The progress endpoint of the current project is consumed as a stream of server-sent events whose
    data is decoded as JSON. A payload containing ``console_output`` is echoed to stdout, a payload
    containing ``error`` aborts the call, and the first payload containing neither key is converted
    into the returned artifact.

    Parameters
    ----------
    pipeline_id : str
        Identifier of the pipeline whose progress endpoint is queried.

    Returns
    -------
    TrainedModelArtifact
        Artifact built with ``TrainedModelArtifact.from_dict`` from the first payload that is neither
        console output nor an error.

    Raises
    ------
    ApiError
        If a payload contains an ``error`` key, or if the event stream ends without yielding a result.
    KeyboardInterrupt
        Re-raised after a cancellation request for the pipeline has been issued.
    """
    with ClientFactory.CURRENT.get()() as client:
        httpx_client = client.get_httpx_client()
        progress_url = f"/{Project.CURRENT.get().model.id}/pipeline/{pipeline_id}/progress"

        with connect_sse(httpx_client, "GET", progress_url) as event_source:
            try:
                for sse_event in event_source.iter_sse():
                    payload = json.loads(sse_event.data)

                    # Console output is only forwarded to the user; keep listening afterwards.
                    if "console_output" in payload:
                        print(payload["console_output"], end="")  # noqa: T201
                        continue

                    if "error" in payload:
                        raise ApiError(payload["error"])

                    # Any other payload is treated as the final result.
                    return TrainedModelArtifact.from_dict(payload)
            except KeyboardInterrupt:
                # TODO(<meziane bellahmer>): verify pipeline cancelling. Will be fixed with new client generator library. https://gitlab.xpdeep.com/xpdeep/xpdeep-client/-/issues/279
                # The user interrupted the wait: ask the backend to cancel the pipeline, then propagate.
                cancel_response = cancel_pipeline.sync(
                    Project.CURRENT.get().model.id,
                    pipeline_id,
                    client=ClientFactory.CURRENT.get()(),
                )
                handle_api_validation_errors(cancel_response)
                raise

            # The stream was exhausted without a result or an explicit error payload.
            failure_message = f"Unexpected error during pipeline `{pipeline_id}` execution"
            raise ApiError(failure_message)