afnio.tellurio.websocket_client
Classes
TellurioWebSocketClient – A WebSocket client for interacting with the Tellurio backend.
- class afnio.tellurio.websocket_client.TellurioWebSocketClient(base_url=None, port=None, default_timeout=30)
Bases: object
A WebSocket client for interacting with the Tellurio backend.
This client establishes a WebSocket connection to the backend, sends requests, listens for responses, and handles reconnections. It supports JSON-RPC-style communication and is designed to work with asynchronous workflows.
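Each rpc_* method below handles the server-side JSON-RPC method with the same name. The following minimal sketch shows how an incoming server message could be routed to such a handler, assuming JSON-RPC 2.0 envelopes and a lookup of `rpc_<method>` by name. `MiniDispatcher`, `handle_message`, and `rpc_echo` are hypothetical illustrations, not part of afnio.

```python
import asyncio
import json
from typing import Optional


class MiniDispatcher:
    """Hypothetical router: maps a JSON-RPC "method" to an rpc_<method> coroutine."""

    async def rpc_echo(self, params: dict) -> dict:
        # Placeholder handler so the sketch has something to dispatch to.
        return {"message": "Ok", "data": params}

    async def handle_message(self, raw: str) -> Optional[dict]:
        msg = json.loads(raw)
        # JSON-RPC 2.0: a message without an "id" is a notification and gets no reply.
        is_notification = "id" not in msg
        handler = getattr(self, f"rpc_{msg.get('method')}", None)
        if handler is None:
            if is_notification:
                return None
            error = {"code": -32601, "message": "Method not found"}
            return {"jsonrpc": "2.0", "id": msg["id"], "error": error}
        result = await handler(msg.get("params", {}))
        if is_notification:
            return None
        return {"jsonrpc": "2.0", "id": msg["id"], "result": result}


if __name__ == "__main__":
    request = {"jsonrpc": "2.0", "id": "1", "method": "echo", "params": {"x": 1}}
    print(asyncio.run(MiniDispatcher().handle_message(json.dumps(request))))
```

Dispatching with `getattr` on a fixed `rpc_` prefix ensures that only methods explicitly named as handlers can be invoked by the server.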
- async call(method, params, timeout=None)
Sends a request over the WebSocket connection and waits for a response.
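Constructs a JSON-RPC request, sends it to the WebSocket server, and waits for the corresponding response. If no response is received within the timeout period, an asyncio.TimeoutError is raised.
- Parameters:
method – The name of the JSON-RPC method to call.
params – The parameters to send with the request.
timeout – The maximum time, in seconds, to wait for a response.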
- Returns:
The result of the method call.
- Return type:
- Raises:
RuntimeError – If the WebSocket connection is not established.
asyncio.TimeoutError – If the response is not received within the timeout period.
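A common way to implement this kind of request/response matching is to store one `asyncio.Future` per request id and resolve it when the listener sees a reply with the same id. The sketch below follows that approach. `RpcCaller`, `on_message`, the injected `send` coroutine, and the `demo` server are hypothetical stand-ins for the real socket and listener, and the sketch assumes that `timeout=None` falls back to a client-wide default.

```python
import asyncio
import json
import uuid
from typing import Any, Awaitable, Callable, Dict, Optional


class RpcCaller:
    """Hypothetical sketch of call(): match each response to its request by id."""

    def __init__(self, send: Callable[[str], Awaitable[None]], default_timeout: float = 30) -> None:
        self._send = send  # assumed coroutine that writes one text frame to the socket
        self.default_timeout = default_timeout
        self.pending: Dict[str, asyncio.Future] = {}

    async def call(self, method: str, params: dict, timeout: Optional[float] = None) -> Any:
        request_id = str(uuid.uuid4())
        future = asyncio.get_running_loop().create_future()
        self.pending[request_id] = future
        request = {"jsonrpc": "2.0", "id": request_id, "method": method, "params": params}
        try:
            await self._send(json.dumps(request))
            wait = self.default_timeout if timeout is None else timeout
            # Raises asyncio.TimeoutError if no response arrives in time.
            return await asyncio.wait_for(future, wait)
        finally:
            self.pending.pop(request_id, None)  # never leak entries, even on timeout

    def on_message(self, raw: str) -> None:
        """Called by the listener task for every incoming frame."""
        msg = json.loads(raw)
        future = self.pending.get(msg.get("id"))
        if future is None or future.done():
            return  # late or unknown response
        if "error" in msg:
            future.set_exception(RuntimeError(msg["error"]))
        else:
            future.set_result(msg.get("result"))


async def demo() -> Any:
    caller: Optional[RpcCaller] = None

    async def fake_send(raw: str) -> None:
        # Stand-in server: answers every request by doubling params["x"].
        req = json.loads(raw)
        reply = {"jsonrpc": "2.0", "id": req["id"], "result": req["params"]["x"] * 2}
        asyncio.get_running_loop().call_soon(caller.on_message, json.dumps(reply))

    caller = RpcCaller(fake_send)
    return await caller.call("double", {"x": 21})
```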
- async close()
Closes the WebSocket connection and cleans up resources.
Cancels the listener task, clears pending requests, and closes the WebSocket connection.
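The cleanup steps above can be sketched as a standalone coroutine. `shutdown` is a hypothetical helper. The documentation only says that pending requests are cleared, so failing each waiting future with `ConnectionError` (rather than silently dropping it) is a design choice made for this sketch.

```python
import asyncio
import contextlib
from typing import Dict, Optional


async def shutdown(listener: Optional[asyncio.Task], pending: Dict[str, asyncio.Future]) -> None:
    """Hypothetical cleanup mirroring the steps close() is documented to perform."""
    # 1. Stop the background listener task and wait for it to finish unwinding.
    if listener is not None and not listener.done():
        listener.cancel()
        with contextlib.suppress(asyncio.CancelledError):
            await listener
    # 2. Fail any callers still waiting so they do not hang until their timeout.
    for future in pending.values():
        if not future.done():
            future.set_exception(ConnectionError("WebSocket connection closed"))
    pending.clear()
    # 3. Closing the socket itself (for example, awaiting its close() method) would follow here.
```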
- async connect(api_key=None, retries=3, delay=5)
Connects to the WebSocket server with retry logic.
Attempts to establish a WebSocket connection to the backend. If the connection fails, it retries up to the specified number of attempts with a delay between each attempt.
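- Parameters:
api_key – The API key used to authenticate with the backend.
retries – The maximum number of connection attempts.
delay – The delay between consecutive attempts.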
- Returns:
The session ID received from the server upon successful connection.
- Return type:
- Raises:
RuntimeError – If the connection fails after all retry attempts.
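A minimal sketch of the retry logic described above. It assumes that `retries` counts total attempts and that `delay` is in seconds. The documentation does not settle either point. `connect_with_retries` and the injected `open_connection` coroutine (for example, one that opens the socket and returns the session ID) are hypothetical.

```python
import asyncio
from typing import Awaitable, Callable, Optional, TypeVar

T = TypeVar("T")


async def connect_with_retries(
    open_connection: Callable[[], Awaitable[T]],
    retries: int = 3,
    delay: float = 5,
) -> T:
    """Hypothetical retry loop: `retries` is treated as the total number of attempts."""
    last_error: Optional[BaseException] = None
    for attempt in range(1, retries + 1):
        try:
            return await open_connection()
        except Exception as exc:  # e.g. OSError or a failed handshake
            last_error = exc
            if attempt < retries:
                await asyncio.sleep(delay)  # no sleep after the final attempt
    # Chain the last underlying error so the root cause stays visible.
    raise RuntimeError(f"Failed to connect after {retries} attempts") from last_error
```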
- async rpc_append_grad(params)
Handle the ‘append_grad’ JSON-RPC method from the server.
This method appends a new gradient variable to the local grad list of the specified Variable instance. It is typically called when the server notifies the client that a new gradient has been added to a variable during the backward pass.
- Parameters:
params (dict) –
A dictionary containing:
variable_id: The unique identifier of the Variable to update.
gradient: The serialized gradient Variable to append.
- Returns:
A dictionary with a success message if the gradient is appended.
- Return type:
dict
- Raises:
KeyError – If required keys are missing from params.
RuntimeError – If appending the gradient fails for any reason.
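The following sketch shows one way this handler could be written against a local registry. `Variable`, `VARIABLE_REGISTRY`, `deserialize_variable`, and the `{"data": ...}` wire format are hypothetical, because the real serialization is not documented here. Missing keys propagate as `KeyError`, and any other failure is wrapped in `RuntimeError`, matching the Raises list above.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Variable:  # hypothetical stand-in, not afnio's Variable
    data: Any
    grad: List["Variable"] = field(default_factory=list)


VARIABLE_REGISTRY: Dict[str, Variable] = {}


def deserialize_variable(payload: dict) -> Variable:
    # Assumed wire format {"data": ...}; the real format is not documented.
    return Variable(data=payload["data"])


async def rpc_append_grad(params: dict) -> dict:
    variable_id = params["variable_id"]  # KeyError if a required key is missing
    gradient = params["gradient"]
    try:
        variable = VARIABLE_REGISTRY[variable_id]
        variable.grad.append(deserialize_variable(gradient))
    except Exception as exc:
        raise RuntimeError(f"Failed to append gradient to {variable_id!r}") from exc
    return {"message": "Gradient appended"}
```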
- async rpc_clear_backward(params)
Handle the ‘clear_backward’ JSON-RPC method from the server.
This method clears the _pending_grad flag for the specified variables. It is called after the server finalizes the backward pass for the entire computation graph, indicating that the gradients for its variables have been computed and already shared with the client. Once it receives ‘clear_backward’, the client can safely access the values of these gradients without worrying about them being modified.
- Parameters:
params (dict) –
A dictionary containing:
variable_ids: A list of variable IDs for which to clear the _pending_grad flag.
- Raises:
KeyError – If required keys are missing from params.
RuntimeError – If clearing the pending grad fails for any variable.
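A minimal sketch of the flag-clearing handler, using a hypothetical `Variable` stand-in and `VARIABLE_REGISTRY`. This version clears every variable it can find and then raises one `RuntimeError` listing the IDs it could not resolve. That is a design choice for the sketch, since the documentation does not say whether clearing stops at the first failure.

```python
from typing import Dict, List


class Variable:  # hypothetical stand-in, not afnio's Variable
    def __init__(self) -> None:
        self._pending_grad = True  # set while the server is still backpropagating


VARIABLE_REGISTRY: Dict[str, Variable] = {}


async def rpc_clear_backward(params: dict) -> None:
    variable_ids = params["variable_ids"]  # KeyError if a required key is missing
    failed: List[str] = []
    for variable_id in variable_ids:
        variable = VARIABLE_REGISTRY.get(variable_id)
        if variable is None:
            failed.append(variable_id)
            continue
        variable._pending_grad = False  # gradients are now final and safe to read
    if failed:
        raise RuntimeError(f"Could not clear _pending_grad for: {failed}")
```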
- async rpc_clear_step(params)
Handle the ‘clear_step’ JSON-RPC method from the server.
This method clears the _pending_data flag for the specified variables. It is called after the server completes an optimizer step and updates the data for the relevant variables. Once ‘clear_step’ is received, the client can safely access the updated values of these variables, knowing that the data is no longer pending or being modified.
- Parameters:
params (dict) –
A dictionary containing:
variable_ids: A list of variable IDs (str) for which to clear the _pending_data flag.
- Raises:
KeyError – If required keys are missing from params.
RuntimeError – If clearing the pending data fails for any variable.
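The guarantee described above, that data can be read safely once 'clear_step' arrives, can be illustrated by letting readers wait until the pending flag is cleared. This sketch substitutes an `asyncio.Event` for a plain boolean flag. `Variable`, `begin_step`, `read_data`, and `VARIABLE_REGISTRY` are hypothetical, and the sketch does not claim that afnio works this way.

```python
import asyncio
from typing import Any, Dict


class Variable:
    """Hypothetical stand-in whose data is unsafe to read while _pending_data is set."""

    def __init__(self, data: Any) -> None:
        self.data = data
        self._ready = asyncio.Event()  # create inside a running event loop
        self._ready.set()  # nothing pending initially

    @property
    def _pending_data(self) -> bool:
        return not self._ready.is_set()

    def begin_step(self) -> None:
        # Assumed: the flag is raised when an optimizer step starts.
        self._ready.clear()

    async def read_data(self) -> Any:
        # Block until 'clear_step' has arrived, then return the settled value.
        await self._ready.wait()
        return self.data


VARIABLE_REGISTRY: Dict[str, Variable] = {}


async def rpc_clear_step(params: dict) -> None:
    for variable_id in params["variable_ids"]:  # KeyError if a required key is missing
        try:
            VARIABLE_REGISTRY[variable_id]._ready.set()  # wakes every waiting reader
        except KeyError as exc:
            raise RuntimeError(f"Unknown variable {variable_id!r}") from exc
```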
- async rpc_create_edge(params)
Handle the ‘create_edge’ JSON-RPC method from the server.
This method creates a GradientEdge between two nodes in the local registry, appending the edge to the from_node’s next_functions. It is typically called when the server notifies the client that a new edge has been created in the computation graph.
Note
The terms ‘from_node’ and ‘to_node’ should be interpreted in the context of the backward pass (backpropagation): the edge is added to the from_node’s next_functions and points to the to_node, following the direction of gradient flow during backpropagation.
- Parameters:
params (dict) –
A dictionary with keys:
from_node_id: The unique identifier of the source node.
to_node_id: The unique identifier of the destination node.
output_nr: The output number associated with the edge.
- Returns:
A dictionary with a success message if the edge is created.
- Return type:
dict
- Raises:
KeyError – If required keys are missing from params.
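The following sketch shows the edge direction described in the note above. The edge is stored on `from_node` and points to `to_node`, so walking `next_functions` follows gradient flow. `Node`, `GradientEdge` (modelled here as a `(node, output_nr)` pair), and `NODE_REGISTRY` are hypothetical stand-ins. In this sketch an unknown node ID also surfaces as `KeyError`.

```python
from typing import Dict, List, NamedTuple


class Node:  # hypothetical stand-in for a backward-graph node
    def __init__(self, name: str) -> None:
        self.name = name
        self.next_functions: List["GradientEdge"] = []


class GradientEdge(NamedTuple):
    node: Node
    output_nr: int


NODE_REGISTRY: Dict[str, Node] = {}


async def rpc_create_edge(params: dict) -> dict:
    # KeyError for missing keys, and also for node IDs not in the registry.
    from_node = NODE_REGISTRY[params["from_node_id"]]
    to_node = NODE_REGISTRY[params["to_node_id"]]
    output_nr = params["output_nr"]
    # Backward direction: from_node passes its gradient on to to_node.
    from_node.next_functions.append(GradientEdge(to_node, output_nr))
    return {"message": "Edge created"}
```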
- async rpc_create_node(params)
Handle the ‘create_node’ JSON-RPC method from the server.
This method creates and registers a new Node instance in the local registry using the provided parameters. It is typically called when the server notifies the client that a new node has been created in the computation graph.
- async rpc_create_variable(params)
Handle the ‘create_variable’ JSON-RPC method from the server.
This method creates and registers a new Variable instance in the local registry using the provided parameters. It is typically called when the server creates a deepcopy of a Variable or Parameter and needs to notify the client.
- Parameters:
params (dict) –
A dictionary with keys:
variable_id: The unique identifier of the Variable.
obj_type: The type of the variable object (e.g., “__variable__” or “__parameter__”).
data: The initial data for the variable.
role: The role or description of the variable.
requires_grad: Whether the variable requires gradients.
_retain_grad: Whether to retain gradients for non-leaf variables.
_grad: The initial gradient(s) for the variable.
_output_nr: The output number for the variable in the computation graph.
_grad_fn: The gradient function associated with the variable.
is_leaf: Whether the variable is a leaf node in the computation graph.
- Returns:
A dictionary with a success message if the variable is created.
- Return type:
dict
- Raises:
KeyError – If required keys are missing from params.
RuntimeError – If the variable creation fails for any reason.
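A minimal sketch of how the documented keys could be mapped onto a local object, with `obj_type` choosing the class. `Variable`, `Parameter`, `_TYPES`, and `VARIABLE_REGISTRY` are hypothetical stand-ins. This sketch stores `_grad` and `_grad_fn` exactly as received, without resolving them against other registries.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional


@dataclass
class Variable:  # hypothetical stand-in, not afnio's Variable
    data: Any
    role: str = ""
    requires_grad: bool = False
    _retain_grad: bool = False
    _grad: List[Any] = field(default_factory=list)
    _output_nr: int = 0
    _grad_fn: Optional[Any] = None
    is_leaf: bool = True


class Parameter(Variable):
    """Hypothetical trainable subclass selected by obj_type "__parameter__"."""


_TYPES = {"__variable__": Variable, "__parameter__": Parameter}
_FIELDS = ("data", "role", "requires_grad", "_retain_grad", "_grad",
           "_output_nr", "_grad_fn", "is_leaf")
VARIABLE_REGISTRY: Dict[str, Variable] = {}


async def rpc_create_variable(params: dict) -> dict:
    variable_id = params["variable_id"]  # KeyError if a required key is missing
    obj_type = params["obj_type"]
    kwargs = {name: params[name] for name in _FIELDS}
    try:
        cls = _TYPES[obj_type]  # an unknown type is reported as a creation failure
        VARIABLE_REGISTRY[variable_id] = cls(**kwargs)
    except Exception as exc:
        raise RuntimeError(f"Failed to create variable {variable_id!r}") from exc
    return {"message": "Variable created"}
```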
- async rpc_heartbeat(params)
Handle the ‘heartbeat’ JSON-RPC notification from the server.
This method is called when the server sends a heartbeat notification for a long-running operation. It updates the last heartbeat timestamp for the corresponding request ID, allowing the client to reset its timeout and avoid prematurely timing out while the server is still processing the request.
- Parameters:
params (dict) –
A dictionary with keys:
id: The request ID (str) associated with the long-running operation.
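The timeout reset described above works like an inactivity timeout: a request fails only after `timeout` seconds with neither a response nor a heartbeat. The sketch below implements that idea with `time.monotonic()` timestamps. `asyncio.shield` keeps the response future alive across each intermediate `wait_for` timeout. `HeartbeatTracker` and `wait_for_response` are hypothetical names.

```python
import asyncio
import time
from typing import Any, Dict


class HeartbeatTracker:
    """Hypothetical sketch: a request times out only after `timeout` seconds of silence."""

    def __init__(self) -> None:
        self.last_heartbeat: Dict[str, float] = {}

    async def rpc_heartbeat(self, params: dict) -> None:
        # Each heartbeat pushes the deadline for that request further out.
        self.last_heartbeat[params["id"]] = time.monotonic()

    async def wait_for_response(self, request_id: str, future: asyncio.Future, timeout: float) -> Any:
        self.last_heartbeat[request_id] = time.monotonic()
        try:
            while True:
                remaining = self.last_heartbeat[request_id] + timeout - time.monotonic()
                if remaining <= 0:
                    raise asyncio.TimeoutError(f"No response or heartbeat for {request_id}")
                try:
                    # shield() stops wait_for from cancelling the real future on timeout.
                    return await asyncio.wait_for(asyncio.shield(future), remaining)
                except asyncio.TimeoutError:
                    continue  # a heartbeat may have moved the deadline; re-check
        finally:
            self.last_heartbeat.pop(request_id, None)
```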
- async rpc_run_callable(params)
Handle the ‘run_callable’ JSON-RPC method from the server.
This method is invoked when the server sends a JSON-RPC request with the method “run_callable”. It extracts callable details from the provided parameters, executes the callable from the registry, and returns a response containing the result. The response is expected to be JSON-serializable.
- Parameters:
params (dict) –
A dictionary containing:
callable_id: A unique identifier for the callable.
args: Positional arguments (as a list or tuple) for the callable.
kwargs: Keyword arguments for the callable.
- Returns:
A dictionary with the following structure:
{"message": "Ok", "data": <result of executing the callable>}
- Return type:
dict
- Raises:
KeyError – If required keys are missing in params.
TypeError – If the result of the callable is not JSON-serializable.
ValueError – If the callable execution fails due to invalid parameters.
RuntimeError – For any other exception encountered during callable execution.
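The following sketch maps each documented exception to the condition that triggers it. `CALLABLE_REGISTRY` is hypothetical, and the sketch assumes the registered callables are synchronous. Treating a `TypeError` raised by the call (for example, wrong arity) as invalid parameters is an assumption. Serializability is checked with `json.dumps`, which raises `TypeError` for unsupported objects.

```python
import json
from typing import Any, Callable, Dict

CALLABLE_REGISTRY: Dict[str, Callable[..., Any]] = {}


async def rpc_run_callable(params: dict) -> dict:
    callable_id = params["callable_id"]  # KeyError if a required key is missing
    args = params["args"]
    kwargs = params["kwargs"]
    fn = CALLABLE_REGISTRY[callable_id]  # an unknown ID also raises KeyError here
    try:
        result = fn(*args, **kwargs)
    except (TypeError, ValueError) as exc:
        # Wrong arity or rejected argument values: treated as invalid parameters.
        raise ValueError(f"Invalid parameters for {callable_id!r}: {exc}") from exc
    except Exception as exc:
        raise RuntimeError(f"Callable {callable_id!r} failed: {exc}") from exc
    try:
        json.dumps(result)  # probe only; the dict itself is serialized by the transport
    except (TypeError, ValueError) as exc:
        raise TypeError(f"Result of {callable_id!r} is not JSON-serializable") from exc
    return {"message": "Ok", "data": result}
```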
- async rpc_update_model(params)
Handle the ‘update_model’ JSON-RPC method from the server.
This method updates a specific field of a registered LM model instance in response to a server notification. It uses the provided parameters to identify the LM model and the field to update.
- Parameters:
params (dict) –
A dictionary with keys:
model_id: The unique identifier of the LM model.
field: The field name to update.
value: The new value to set for the field.
- Returns:
A dictionary with a success message if the update is successful.
- Return type:
dict
- Raises:
KeyError – If required keys are missing from params.
RuntimeError – If the update fails for any reason.
- async rpc_update_variable(params)
Handle the ‘update_variable’ JSON-RPC method from the server.
This method updates a specific field of a registered Variable instance in response to a server notification. It uses the provided parameters to identify the variable and the field to update.
- Parameters:
params (dict) –
A dictionary with keys:
variable_id: The unique identifier of the Variable.
field: The field name to update.
value: The new value to set for the field.
- Returns:
A dictionary with a success message if the update is successful.
- Return type:
dict
- Raises:
KeyError – If required keys are missing from params.
RuntimeError – If the update fails for any reason.
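'update_variable' and 'update_model' take the same shape of request: look up an object by ID and set one field on it. The sketch below shows the variable case. A model version would key on `model_id` into a model registry instead. `Variable` and `VARIABLE_REGISTRY` are hypothetical stand-ins. The `hasattr` guard is a choice made for this sketch, so that a mistyped field name raises an error instead of silently creating a new attribute.

```python
from typing import Any, Dict


class Variable:  # hypothetical stand-in; an LM model object would be handled the same way
    def __init__(self, data: Any) -> None:
        self.data = data
        self.role = ""


VARIABLE_REGISTRY: Dict[str, Any] = {}


async def rpc_update_variable(params: dict) -> dict:
    variable_id = params["variable_id"]  # KeyError if a required key is missing
    field_name = params["field"]
    value = params["value"]
    try:
        variable = VARIABLE_REGISTRY[variable_id]
        if not hasattr(variable, field_name):
            raise AttributeError(f"{type(variable).__name__} has no field {field_name!r}")
        setattr(variable, field_name, value)
    except Exception as exc:
        raise RuntimeError(f"Failed to update {field_name!r} on {variable_id!r}") from exc
    return {"message": "Variable updated"}
```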