afnio.tellurio.websocket_client

Classes

TellurioWebSocketClient([base_url, port, ...])

A WebSocket client for interacting with the Tellurio backend.

class afnio.tellurio.websocket_client.TellurioWebSocketClient(base_url=None, port=None, default_timeout=30)

Bases: object

A WebSocket client for interacting with the Tellurio backend.

This client establishes a WebSocket connection to the backend, sends requests, listens for responses, and handles reconnections. It communicates using JSON-RPC-style messages, and all of its methods are coroutines that must be awaited from asynchronous code.
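A minimal sketch of how incoming server messages might be routed to the rpc_* handlers documented below. It assumes JSON-RPC 2.0 framing (jsonrpc, id, method, params) and a naming convention in which a server method X is handled by a coroutine named rpc_X; MiniClient and dispatch are hypothetical names used only for illustration, not afnio's actual code.

```python
import asyncio
import json


class MiniClient:
    """Hypothetical stand-in that routes server messages to rpc_* handlers."""

    async def rpc_heartbeat(self, params):
        # Notification handler: nothing is sent back to the server.
        return None

    async def rpc_create_node(self, params):
        return {"message": f"Node {params['node_id']} created"}

    async def dispatch(self, raw: str):
        msg = json.loads(raw)
        # In JSON-RPC 2.0, a message without an "id" is a notification.
        is_notification = "id" not in msg
        handler = getattr(self, "rpc_" + msg["method"], None)
        if handler is None:
            if is_notification:
                return None
            # -32601 is the JSON-RPC 2.0 "Method not found" error code.
            return {"jsonrpc": "2.0", "id": msg["id"],
                    "error": {"code": -32601, "message": "Method not found"}}
        result = await handler(msg.get("params", {}))
        if is_notification:
            return None
        return {"jsonrpc": "2.0", "id": msg["id"], "result": result}


client = MiniClient()
request = json.dumps({"jsonrpc": "2.0", "id": "1", "method": "create_node",
                      "params": {"node_id": "n1", "name": "AddBackward"}})
reply = asyncio.run(client.dispatch(request))
print(reply["result"])  # {'message': 'Node n1 created'}
```

Dispatching by name keeps the listener loop generic: adding a new server method only requires adding a new rpc_* coroutine.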

async call(method, params, timeout=None)

Sends a request over the WebSocket connection and waits for a response.

Constructs a JSON-RPC request, sends it to the WebSocket server, and waits for the corresponding response. If no response is received within the timeout period, a TimeoutError is raised.
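The request/response matching described above can be sketched with one asyncio.Future per request ID and asyncio.wait_for for the timeout. This is a sketch under stated assumptions: an in-memory asyncio.Queue stands in for the WebSocket, and PendingCalls, fake_server, and the UUID-based request IDs are hypothetical.

```python
import asyncio
import json
import uuid


class PendingCalls:
    """Hypothetical helper: correlates JSON-RPC responses with their requests."""

    def __init__(self, outbox: asyncio.Queue, default_timeout: float = 30):
        self.outbox = outbox  # stands in for the WebSocket send side
        self.default_timeout = default_timeout
        self.pending = {}     # request id -> Future awaiting the response

    async def call(self, method, params, timeout=None):
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        await self.outbox.put(json.dumps(
            {"jsonrpc": "2.0", "id": request_id, "method": method, "params": params}))
        wait = timeout if timeout is not None else self.default_timeout
        try:
            # Raises asyncio.TimeoutError (the built-in TimeoutError since 3.11).
            return await asyncio.wait_for(future, wait)
        finally:
            self.pending.pop(request_id, None)

    def on_message(self, raw: str):
        # Called by the listener loop for every incoming message.
        msg = json.loads(raw)
        future = self.pending.get(msg.get("id"))
        if future is not None and not future.done():
            future.set_result(msg.get("result"))


async def fake_server(outbox: asyncio.Queue, client: PendingCalls):
    # Echo the params back as the result, standing in for the backend.
    req = json.loads(await outbox.get())
    client.on_message(json.dumps({"jsonrpc": "2.0", "id": req["id"],
                                  "result": {"echo": req["params"]}}))


async def main():
    outbox = asyncio.Queue()
    client = PendingCalls(outbox, default_timeout=1)
    server = asyncio.create_task(fake_server(outbox, client))
    result = await client.call("ping", {"x": 1})
    await server
    return result


print(asyncio.run(main()))  # {'echo': {'x': 1}}
```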

Parameters:
  • method (str) – The name of the method to call on the backend.

  • params (dict) – The parameters to pass to the method.

  • timeout (int, optional) – The timeout (in seconds) for the response. If not provided, the client's default_timeout (30 seconds unless another value is passed to the constructor) is used.

Returns:

The result of the method call.

Return type:

dict

Raises:

TimeoutError – If no response is received within the timeout period.

async close()

Closes the WebSocket connection and cleans up resources.

Cancels the listener task, clears pending requests, and closes the WebSocket connection.
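A sketch of the cleanup order described above, assuming pending requests are tracked as futures keyed by request ID. Failing each pending future before clearing the dictionary (a design choice of this sketch; the real client may simply clear it) lets any coroutine still awaiting call() return promptly instead of waiting out its timeout. ListenerState and the use of ConnectionError are hypothetical.

```python
import asyncio
import contextlib


class ListenerState:
    """Hypothetical holder for the listener task and pending request futures."""

    def __init__(self):
        self.listener_task = None
        self.pending = {}
        self.closed = False

    async def close(self):
        # 1. Stop the background listener and wait for the cancellation to finish.
        if self.listener_task is not None:
            self.listener_task.cancel()
            with contextlib.suppress(asyncio.CancelledError):
                await self.listener_task
            self.listener_task = None
        # 2. Fail any request still waiting for a response, then forget it.
        for future in self.pending.values():
            if not future.done():
                future.set_exception(ConnectionError("connection closed"))
        self.pending.clear()
        # 3. A real client would now await the WebSocket's own close().
        self.closed = True


async def demo():
    state = ListenerState()
    state.listener_task = asyncio.create_task(asyncio.sleep(3600))
    waiting = asyncio.get_running_loop().create_future()
    state.pending["req-1"] = waiting
    await state.close()
    return state.closed, state.pending, type(waiting.exception()).__name__


print(asyncio.run(demo()))  # (True, {}, 'ConnectionError')
```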

async connect(api_key=None, retries=3, delay=5)

Connects to the WebSocket server with retry logic.

Attempts to establish a WebSocket connection to the backend. If an attempt fails, it tries again, up to the number of attempts given by retries, waiting delay seconds between attempts.
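The retry behaviour can be sketched as a loop around the connection attempt. Here open_connection is a hypothetical coroutine standing in for the WebSocket handshake plus authentication, returning a session ID; the sketch also assumes that retries counts total attempts and that transport failures surface as OSError, neither of which is confirmed above.

```python
import asyncio


async def connect_with_retry(open_connection, retries=3, delay=5):
    """Sketch: call open_connection() up to `retries` times, sleeping `delay` s between tries."""
    last_error = None
    for attempt in range(1, retries + 1):
        try:
            return await open_connection()
        except OSError as exc:  # assumption: transport failures raise OSError
            last_error = exc
            if attempt < retries:
                await asyncio.sleep(delay)
    raise RuntimeError(f"Failed to connect after {retries} attempts") from last_error


def make_flaky(failures):
    """Hypothetical connection that fails `failures` times before succeeding."""
    calls = {"n": 0}

    async def open_connection():
        calls["n"] += 1
        if calls["n"] <= failures:
            raise ConnectionRefusedError("backend not ready")  # an OSError subclass
        return "session-abc"

    return open_connection


print(asyncio.run(connect_with_retry(make_flaky(2), retries=3, delay=0)))  # session-abc
```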

Parameters:
  • api_key (str) – The API key for authenticating with the backend.

  • retries (int) – The number of reconnection attempts (default: 3).

  • delay (int) – The delay (in seconds) between reconnection attempts (default: 5).

Returns:

The session ID received from the server upon successful connection.

Return type:

str

Raises:

RuntimeError – If the connection fails after all retry attempts.

async rpc_append_grad(params)

Handle the ‘append_grad’ JSON-RPC method from the server.

This method appends a new gradient variable to the local grad list of the specified Variable instance. It is typically called when the server notifies the client that a new gradient has been added to a variable during the backward pass.
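A minimal sketch of this handler, assuming a registry that maps variable IDs to objects with a grad list, plus a deserialization step for the incoming gradient. VARIABLES, FakeVariable, deserialize, and the exact success message are hypothetical stand-ins, not afnio's internals.

```python
import asyncio
from dataclasses import dataclass, field


@dataclass
class FakeVariable:
    data: str
    grad: list = field(default_factory=list)


VARIABLES = {"v1": FakeVariable(data="The capital of France is Paris.")}


def deserialize(payload):
    # Hypothetical: rebuild a gradient Variable from its serialized form.
    return FakeVariable(data=payload["data"])


async def rpc_append_grad(params):
    variable_id = params["variable_id"]  # KeyError if missing
    serialized = params["gradient"]      # KeyError if missing
    try:
        variable = VARIABLES[variable_id]
        variable.grad.append(deserialize(serialized))
    except Exception as exc:
        raise RuntimeError(f"Failed to append gradient to {variable_id!r}") from exc
    return {"message": "Gradient appended successfully"}


result = asyncio.run(rpc_append_grad(
    {"variable_id": "v1", "gradient": {"data": "Be more concise."}}))
print(result, len(VARIABLES["v1"].grad))
# {'message': 'Gradient appended successfully'} 1
```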

Parameters:

params (dict) –

A dictionary containing:

  • variable_id: The unique identifier of the Variable to update.

  • gradient: The serialized gradient Variable to append.

Returns:

A dictionary with a success message if the gradient is appended.

Return type:

dict

Raises:
  • KeyError – If required keys are missing from params.

  • RuntimeError – If appending the gradient fails for any reason.

async rpc_clear_backward(params)

Handle the ‘clear_backward’ JSON-RPC method from the server.

This method clears the _pending_grad flag for the specified variables. It is called after the server finalizes the backward pass for the entire computation graph, signalling that the gradients of the graph's variables have been computed and already sent to the client. Once it receives ‘clear_backward’, the client can safely read these gradients, knowing they will no longer be modified.
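A sketch of the flag-clearing step, assuming each registered variable carries a boolean _pending_grad attribute and that an unknown ID is reported as a RuntimeError. REGISTRY and the SimpleNamespace stand-ins are hypothetical.

```python
import asyncio
from types import SimpleNamespace

REGISTRY = {
    "a": SimpleNamespace(_pending_grad=True),
    "b": SimpleNamespace(_pending_grad=True),
}


async def rpc_clear_backward(params):
    variable_ids = params["variable_ids"]  # KeyError if missing
    for variable_id in variable_ids:
        try:
            REGISTRY[variable_id]._pending_grad = False
        except KeyError as exc:
            raise RuntimeError(f"Unknown variable {variable_id!r}") from exc


asyncio.run(rpc_clear_backward({"variable_ids": ["a", "b"]}))
print([v._pending_grad for v in REGISTRY.values()])  # [False, False]
```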

Parameters:

params (dict) –

A dictionary containing:

  • variable_ids: A list of variable IDs for which to clear the _pending_grad flag.

Raises:
  • KeyError – If required keys are missing from params.

  • RuntimeError – If clearing the pending grad fails for any variable.

async rpc_clear_step(params)

Handle the ‘clear_step’ JSON-RPC method from the server.

This method clears the _pending_data flag for the specified variables. It is called after the server completes an optimizer step and updates the data for the relevant variables. Once ‘clear_step’ is received, the client can safely access the updated values of these variables, knowing that the data is no longer pending or being modified.
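On the client side, the _pending_data flag matters to code that reads a variable right after an optimizer step. One way to let such code wait for ‘clear_step’ instead of polling is an asyncio.Event per variable. This is a hypothetical design (PendingData, mark_pending, wait_ready), not necessarily how afnio gates access.

```python
import asyncio


class PendingData:
    """Hypothetical variable whose data is unsafe to read while an update is pending."""

    def __init__(self, data):
        self.data = data
        self._ready = asyncio.Event()
        self._ready.set()  # not pending initially

    def mark_pending(self):
        self._ready.clear()

    def clear_pending(self):
        # What a 'clear_step' handler would do for each listed variable ID.
        self._ready.set()

    async def wait_ready(self):
        await self._ready.wait()
        return self.data


async def demo():
    var = PendingData("old prompt")
    var.mark_pending()                  # optimizer step starts on the server
    reader = asyncio.create_task(var.wait_ready())
    await asyncio.sleep(0)              # reader is now blocked on the event
    var.data = "new prompt"             # server pushes the updated data
    var.clear_pending()                 # 'clear_step' arrives
    return await reader


print(asyncio.run(demo()))  # new prompt
```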

Parameters:

params (dict) –

A dictionary containing:

  • variable_ids: A list of variable IDs (str) for which to clear the _pending_data flag.

Raises:
  • KeyError – If required keys are missing from params.

  • RuntimeError – If clearing the pending data fails for any variable.

async rpc_create_edge(params)

Handle the ‘create_edge’ JSON-RPC method from the server.

This method creates a GradientEdge between two nodes in the local registry, appending the edge to the from_node’s next_functions. It is typically called when the server notifies the client that a new edge has been created in the computation graph.

Note

The terms ‘from_node’ and ‘to_node’ should be interpreted in the context of the backward pass (backpropagation): the edge is added to the from_node’s next_functions and points to the to_node, following the direction of gradient flow during backpropagation.
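To make the direction concrete: if a forward computation produces y from x, the backward node for y gets an edge pointing to the node for x, because gradients flow from y back to x. The sketch below assumes an edge is a (node, output_nr) pair, similar to PyTorch's autograd GradientEdge, and uses a hypothetical dict-based node registry (NODES) and Node class.

```python
import asyncio
from typing import NamedTuple


class Node:
    def __init__(self, name):
        self.name = name
        self.next_functions = []


class GradientEdge(NamedTuple):
    node: Node
    output_nr: int


NODES = {"n_y": Node("MulBackward"), "n_x": Node("AccumulateGrad")}


async def rpc_create_edge(params):
    from_node = NODES[params["from_node_id"]]  # KeyError if missing or unknown
    to_node = NODES[params["to_node_id"]]
    # The edge lives on from_node and points along the gradient flow to to_node.
    from_node.next_functions.append(GradientEdge(to_node, params["output_nr"]))
    return {"message": "Edge created successfully"}


asyncio.run(rpc_create_edge({"from_node_id": "n_y", "to_node_id": "n_x", "output_nr": 0}))
edge = NODES["n_y"].next_functions[0]
print(edge.node.name, edge.output_nr)  # AccumulateGrad 0
```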

Parameters:

params (dict) –

A dictionary with keys:

  • from_node_id: The unique identifier of the source node.

  • to_node_id: The unique identifier of the destination node.

  • output_nr: The output number associated with the edge.

Returns:

A dictionary with a success message if the edge is created.

Return type:

dict

Raises:

KeyError – If required keys are missing from params.

async rpc_create_node(params)

Handle the ‘create_node’ JSON-RPC method from the server.

This method creates and registers a new Node instance in the local registry using the provided parameters. It is typically called when the server notifies the client that a new node has been created in the computation graph.

Parameters:

params (dict) –

A dictionary with keys:

  • node_id: The unique identifier of the Node.

  • name: The class name or type of the Node.

Returns:

A dictionary with a success message if the node is created.

Return type:

dict

Raises:

KeyError – If required keys are missing from params.

async rpc_create_variable(params)

Handle the ‘create_variable’ JSON-RPC method from the server.

This method creates and registers a new Variable instance in the local registry using the provided parameters. It is typically called when the server creates a deep copy of a Variable or Parameter and needs to notify the client.
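A sketch of turning the params payload (keys as listed under Parameters) into a local object. It assumes obj_type selects between a variable class and a parameter subclass and that the remaining keys map onto attributes of the same name. Variable, Parameter, TYPES, and REGISTRY here are hypothetical stand-ins rather than afnio's own classes, and _grad_fn is stored as-is instead of being resolved against a node registry.

```python
import asyncio


class Variable:
    def __init__(self, data=None, role="", requires_grad=False):
        self.data = data
        self.role = role
        self.requires_grad = requires_grad


class Parameter(Variable):
    """Hypothetical trainable subclass."""


TYPES = {"__variable__": Variable, "__parameter__": Parameter}
REGISTRY = {}
EXTRA_ATTRS = ("_retain_grad", "_grad", "_output_nr", "_grad_fn", "is_leaf")


async def rpc_create_variable(params):
    cls = TYPES[params["obj_type"]]  # KeyError on a missing or unknown type
    try:
        var = cls(data=params["data"], role=params["role"],
                  requires_grad=params["requires_grad"])
        for key in EXTRA_ATTRS:
            setattr(var, key, params[key])
    except KeyError:
        raise  # missing keys surface as KeyError
    except Exception as exc:
        raise RuntimeError("Variable creation failed") from exc
    REGISTRY[params["variable_id"]] = var
    return {"message": "Variable created successfully"}


payload = {
    "variable_id": "p1", "obj_type": "__parameter__",
    "data": "You are a helpful assistant.", "role": "system prompt",
    "requires_grad": True, "_retain_grad": False, "_grad": [],
    "_output_nr": 0, "_grad_fn": None, "is_leaf": True,
}
print(asyncio.run(rpc_create_variable(payload)), type(REGISTRY["p1"]).__name__)
# {'message': 'Variable created successfully'} Parameter
```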

Parameters:

params (dict) –

A dictionary with keys:

  • variable_id: The unique identifier of the Variable.

  • obj_type: The type of the variable object (e.g., “__variable__” or “__parameter__”).

  • data: The initial data for the variable.

  • role: The role or description of the variable.

  • requires_grad: Whether the variable requires gradients.

  • _retain_grad: Whether to retain gradients for non-leaf variables.

  • _grad: The initial gradient(s) for the variable.

  • _output_nr: The output number for the variable in the computation graph.

  • _grad_fn: The gradient function associated with the variable.

  • is_leaf: Whether the variable is a leaf node in the computation graph.

Returns:

A dictionary with a success message if the variable is created.

Return type:

dict

Raises:
  • KeyError – If required keys are missing from params.

  • RuntimeError – If the variable creation fails for any reason.

async rpc_heartbeat(params)

Handle the ‘heartbeat’ JSON-RPC notification from the server.

This method is called when the server sends a heartbeat notification for a long-running operation. It updates the last heartbeat timestamp for the corresponding request ID, allowing the client to reset its timeout and avoid prematurely timing out while the server is still processing the request.
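One way to implement the timeout reset described above: instead of a single asyncio.wait_for over the whole timeout, wait in short slices and give up only when neither a response nor a heartbeat has arrived for timeout seconds. HeartbeatWaiter, last_heartbeat, wait_response, and the 10 ms polling interval are assumptions for illustration.

```python
import asyncio
import time


class HeartbeatWaiter:
    """Hypothetical: waits for a response, extending the deadline on each heartbeat."""

    def __init__(self):
        self.last_heartbeat = {}

    async def rpc_heartbeat(self, params):
        # Record when the server last reported progress on this request.
        self.last_heartbeat[params["id"]] = time.monotonic()

    async def wait_response(self, request_id, future, timeout, poll=0.01):
        self.last_heartbeat[request_id] = time.monotonic()
        try:
            while True:
                try:
                    # shield() stops wait_for from cancelling the future on each slice.
                    return await asyncio.wait_for(asyncio.shield(future), poll)
                except asyncio.TimeoutError:
                    idle = time.monotonic() - self.last_heartbeat[request_id]
                    if idle > timeout:
                        raise TimeoutError(f"No response or heartbeat for {request_id!r}") from None
        finally:
            self.last_heartbeat.pop(request_id, None)


async def demo(send_heartbeats):
    waiter = HeartbeatWaiter()
    future = asyncio.get_running_loop().create_future()

    async def slow_server():
        # Simulated backend: about 0.3 s of work, optionally heartbeating every 0.02 s.
        for _ in range(15):
            await asyncio.sleep(0.02)
            if send_heartbeats:
                await waiter.rpc_heartbeat({"id": "req-1"})
        future.set_result({"message": "done"})

    server = asyncio.create_task(slow_server())
    try:
        return await waiter.wait_response("req-1", future, timeout=0.1)
    except TimeoutError:
        return "timed out"
    finally:
        server.cancel()


print(asyncio.run(demo(True)))   # {'message': 'done'}
print(asyncio.run(demo(False)))  # timed out
```

With heartbeats the 0.3 s operation succeeds despite a 0.1 s timeout; without them the same operation times out.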

Parameters:

params (dict) –

A dictionary with keys:

  • id: The request ID (str) associated with the long-running operation.

async rpc_run_callable(params)

Handle the ‘run_callable’ JSON-RPC method from the server.

This method is invoked when the server sends a JSON-RPC request with the method “run_callable”. It looks up the callable in the registry by its callable_id, executes it with the provided args and kwargs, and returns a response containing the result. The result must be JSON-serializable.
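A sketch of the lookup, execution, and serializability check, assuming callables are registered in a dict keyed by callable_id and that json.dumps is used as the JSON-serializability test. CALLABLES and the registered word_count function are hypothetical; mapping a TypeError raised by the call itself (for example, wrong arity) to ValueError is this sketch's reading of "invalid parameters".

```python
import asyncio
import json

CALLABLES = {}


def word_count(text):
    return len(text.split())


CALLABLES["cb-1"] = word_count


async def rpc_run_callable(params):
    fn = CALLABLES[params["callable_id"]]            # KeyError if missing or unknown
    args, kwargs = params["args"], params["kwargs"]  # KeyError if missing
    try:
        result = fn(*args, **kwargs)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"Invalid parameters for {params['callable_id']!r}") from exc
    except Exception as exc:
        raise RuntimeError("Callable execution failed") from exc
    try:
        json.dumps(result)  # raises TypeError for non-serializable results
    except TypeError as exc:
        raise TypeError(f"Result of type {type(result).__name__} is not JSON-serializable") from exc
    return {"message": "Ok", "data": result}


print(asyncio.run(rpc_run_callable(
    {"callable_id": "cb-1", "args": ["Hello WebSocket world"], "kwargs": {}})))
# {'message': 'Ok', 'data': 3}
```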

Parameters:

params (dict) –

A dictionary containing:

  • callable_id: A unique identifier for the callable.

  • args: Positional arguments (as a list or tuple) for the callable.

  • kwargs: Keyword arguments for the callable.

Returns:

A dictionary with the following structure:

{"message": "Ok", "data": <result of executing the callable>}

Return type:

dict

Raises:
  • KeyError – If required keys are missing in params.

  • TypeError – If the result of the callable is not JSON-serializable.

  • ValueError – If the callable execution fails due to invalid parameters.

  • RuntimeError – For any other exception encountered during callable execution.

async rpc_update_model(params)

Handle the ‘update_model’ JSON-RPC method from the server.

This method updates a specific field of a registered LM model instance in response to a server notification. It uses the provided parameters to identify the LM model and the field to update.
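A sketch of the field update, assuming models live in a registry keyed by model_id. It adds a guard, not stated in the docs, that only existing attributes may be updated, so a typo in field fails loudly instead of silently creating a new attribute. MODELS and the SimpleNamespace model are hypothetical; rpc_update_variable would presumably follow the same shape with a variable registry.

```python
import asyncio
from types import SimpleNamespace

MODELS = {"lm-1": SimpleNamespace(model="my-model", temperature=0.7)}


async def rpc_update_model(params):
    # KeyError if any required key is missing.
    model_id, name, value = params["model_id"], params["field"], params["value"]
    try:
        model = MODELS[model_id]
        if not hasattr(model, name):
            raise AttributeError(f"model has no field {name!r}")
        setattr(model, name, value)
    except Exception as exc:
        raise RuntimeError(f"Failed to update {name!r} on model {model_id!r}") from exc
    return {"message": "Model updated successfully"}


asyncio.run(rpc_update_model({"model_id": "lm-1", "field": "temperature", "value": 0.2}))
print(MODELS["lm-1"].temperature)  # 0.2
```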

Parameters:

params (dict) –

A dictionary with keys:

  • model_id: The unique identifier of the LM model.

  • field: The field name to update.

  • value: The new value to set for the field.

Returns:

A dictionary with a success message if the update is successful.

Return type:

dict

Raises:
  • KeyError – If required keys are missing from params.

  • RuntimeError – If the update fails for any reason.

async rpc_update_variable(params)

Handle the ‘update_variable’ JSON-RPC method from the server.

This method updates a specific field of a registered Variable instance in response to a server notification. It uses the provided parameters to identify the variable and the field to update.

Parameters:

params (dict) –

A dictionary with keys:

  • variable_id: The unique identifier of the Variable.

  • field: The field name to update.

  • value: The new value to set for the field.

Returns:

A dictionary with a success message if the update is successful.

Return type:

dict

Raises:
  • KeyError – If required keys are missing from params.

  • RuntimeError – If the update fails for any reason.