# Source code for afnio._model_client

import logging
import os
from typing import Any, Dict, List, Optional

from afnio.logging_config import configure_logging
from afnio.models import ChatCompletionModel
from afnio.models.openai import OpenAI
from afnio.tellurio._client_manager import get_default_clients
from afnio.tellurio._eventloop import run_in_background_loop

# Run the package's logging setup (afnio.logging_config.configure_logging) as a
# side effect of importing this module, then bind the module-level logger used
# by the code below.
configure_logging()
logger = logging.getLogger(__name__)


class ModelClientSingleton:
    """
    Singleton holder for the chat model client used by backward operations.

    ``__new__`` always returns the same instance. The instance starts
    unconfigured (``_client`` is ``None``) and is configured through
    ``_initialize``; the public methods then delegate to the wrapped client.
    """

    _instance = None
    _client: Optional[ChatCompletionModel] = None
    _provider: Optional[str] = None
    _model: Optional[str] = None
    # Class-level defaults only: `_initialize` rebinds these on the instance and
    # nothing mutates them in place, so the shared `{}` objects are never modified.
    _client_args: Dict[str, Any] = {}
    _completion_args: Dict[str, Any] = {}

    def __new__(cls):
        """Return the single shared instance, creating it on first use."""
        if cls._instance is None:
            cls._instance = super(ModelClientSingleton, cls).__new__(cls)
        return cls._instance

    def _initialize(
        self,
        provider: str,
        model: str,
        client_args: Optional[Dict[str, Any]] = None,
        completion_args: Optional[Dict[str, Any]] = None,
    ):
        """
        Configure the singleton, (re)creating the client only when needed.

        The client is rebuilt when none exists yet or when the provider, model,
        client arguments or completion arguments differ from the stored ones.

        Args:
            provider (str): Provider name (only ``"openai"`` is handled by
                ``_create_client``).
            model (str): Model name used by default in ``chat``.
            client_args (Dict): Keyword arguments for the client constructor.
                ``None`` is treated as an empty dict.
            completion_args (Dict): Default keyword arguments merged into every
                ``chat`` call. ``None`` is treated as an empty dict.

        Raises:
            ValueError: If the provider is not supported.
            Exception: Whatever ``_create_client`` raises; in that case the
                previously stored configuration is left untouched.
        """
        # Normalize *before* comparing so that `None` and `{}` are equivalent;
        # shallow-copy so later mutation of the caller's dicts cannot silently
        # change the stored configuration.
        new_client_args = dict(client_args) if client_args else {}
        new_completion_args = dict(completion_args) if completion_args else {}

        unchanged = (
            self._client is not None
            and provider == self._provider
            and model == self._model
            and new_client_args == self._client_args
            and new_completion_args == self._completion_args
        )
        if unchanged:
            return

        # Build the client first and commit the new state only on success, so a
        # failed re-initialization never pairs new settings with the old client.
        client = self._create_client(
            provider, model, new_client_args, new_completion_args
        )
        self._provider = provider
        self._model = model
        self._client_args = new_client_args
        self._completion_args = new_completion_args
        self._client = client

    def _create_client(
        self,
        provider: str,
        model_name: str = None,
        client_args: Optional[Dict[str, Any]] = None,
        completion_args: Optional[Dict[str, Any]] = None,
    ) -> ChatCompletionModel:
        """
        Instantiate the client for ``provider`` and register it on the server.

        Args:
            provider (str): Provider name; only ``"openai"`` is supported.
            model_name (str): Model name forwarded to the server.
            client_args (Dict): Keyword arguments for the ``OpenAI`` constructor.
            completion_args (Dict): Completion arguments forwarded to the server.

        Returns:
            ChatCompletionModel: The newly created client.

        Raises:
            ValueError: If ``provider`` is not supported.
            RuntimeError: If the server rejects or mis-echoes the registration.
        """
        # Future providers can be added here
        if provider != "openai":
            raise ValueError(f"Unsupported provider: {provider}.")

        # TODO: Define default values for each provider. Could use configuration
        #       file(s) with settings for each provider and model
        # `client_args` defaults to None, which cannot be unpacked with `**`.
        client_args = client_args or {}
        model = OpenAI(**client_args)
        self._rpc_set_model_singleton(
            model_id=model.model_id,
            model_name=model_name,
            client_args=client_args,
            completion_args=completion_args,
        )
        return model

    def _rpc_set_model_singleton(
        self,
        model_id: str = None,
        model_name: str = None,
        client_args: Optional[Dict[str, Any]] = None,
        completion_args: Optional[Dict[str, Any]] = None,
    ):
        """
        Send a ``set_model_singleton`` call to the server and verify its reply.

        The call goes through the default websocket client and is executed via
        ``run_in_background_loop``. The reply must echo the same ``model_id``
        and ``completion_args`` that were sent.

        Raises:
            RuntimeError: If the reply contains an ``"error"`` entry or the
                echoed values do not match.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            # Get the singleton websocket client
            _, ws_client = get_default_clients()

            payload = {
                "model_id": model_id,
                "model_name": model_name,
                "client_args": client_args,
                "completion_args": completion_args,
            }
            response = run_in_background_loop(
                ws_client.call("set_model_singleton", payload)
            )
            if "error" in response:
                # Prefer the server-side exception text when present, but do not
                # assume the error carries a `data` mapping: a missing field
                # must not mask the real error behind a KeyError.
                error = response["error"]
                data = error.get("data") if isinstance(error, dict) else None
                detail = (
                    data.get("exception", error) if isinstance(data, dict) else error
                )
                raise RuntimeError(detail)

            # Check server response
            result = response["result"]
            if (
                result["model_id"] != model_id
                or result["completion_args"] != completion_args
            ):
                raise RuntimeError(
                    f"Server response mismatch: (received {result!r}, "
                    f"but expected model_id={model_id!r}, "
                    f"completion_args={completion_args!r})"
                )
            logger.debug(
                "Model singleton set on server and confirmed: "
                "model_id=%r, completion_args=%r",
                model_id,
                completion_args,
            )
        except Exception as e:
            logger.exception("Failed to set model singleton on server: %s", e)
            raise

    def _require_client(self) -> ChatCompletionModel:
        """Return the configured client or raise ``RuntimeError`` if unset."""
        if self._client is None:
            raise RuntimeError(
                "Model client is not set. Call `set_backward_model_client` first."
            )
        return self._client

    def chat(self, messages: List[Dict[str, str]], **override_kwargs):
        """
        Forward a chat request to the configured client.

        Args:
            messages (List[Dict[str, str]]): Messages passed to the client as-is.
            **override_kwargs: Per-call arguments; they take precedence over the
                stored completion arguments. A ``model`` entry (from either
                source) replaces the model given at initialization.

        Returns:
            Whatever the wrapped client's ``chat`` returns.

        Raises:
            RuntimeError: If no client has been set.
        """
        client = self._require_client()

        # Merge completion_args with overrides
        kwargs = {**self._completion_args, **override_kwargs}
        model = kwargs.pop("model", self._model)

        return client.chat(messages=messages, model=model, **kwargs)

    def get_provider(self) -> Optional[str]:
        """
        Returns the model provider name.

        Raises:
            RuntimeError: If no client has been set.
        """
        return self._require_client().get_provider()

    def get_usage(self) -> Dict[str, int]:
        """
        Retrieves the current token usage statistics.

        Returns:
            Dict[str, int]: A dictionary containing cumulative token usage
                statistics since the model instance was initialized.

        Raises:
            RuntimeError: If no client has been set.

        Example:
            >>> model = get_backward_model_client()
            >>> model.get_usage()
            {'prompt_tokens': 1500, 'completion_tokens': 1200, 'total_tokens': 2700}
        """
        return self._require_client().get_usage()

    def clear_usage(self) -> None:
        """
        Clears the token usage statistics.

        This resets all numerical values in the usage dictionary to zero (including
        nested values), while preserving the dictionary structure.

        Raises:
            RuntimeError: If no client has been set.
        """
        self._require_client().clear_usage()


# Global singleton instance, created at import time. It starts without a client;
# `set_backward_model_client` configures it and `get_backward_model_client`
# returns it.
_model_singleton = ModelClientSingleton()


def set_backward_model_client(
    model_path: str = "openai/gpt-4o",
    client_args: Optional[Dict[str, Any]] = None,
    completion_args: Optional[Dict[str, Any]] = None,
):
    """
    Set the global model client for backward operations.

    Args:
        model_path (str): Path in the format ``provider/model_name``
            (e.g., ``"openai/gpt-4o"``). Default: ``"openai/gpt-4o"``.
        client_args (Dict): Arguments to initialize the model client such as:

            - ``api_key`` (str): The client API key. When the key is absent,
              the ``OPENAI_API_KEY`` environment variable is used (``None``
              if that is unset too).
            - ``organization`` (str): The organization to bill.
            - ``base_url`` (str): The model base endpoint URL (useful when
              models are behind a proxy).
            - etc.

            The dictionary is copied; the caller's object is not modified.
        completion_args (Dict): Arguments to pass to ``chat()`` during usage
            such as:

            - ``model`` (str): The model to use (e.g., ``gpt-4o``).
            - ``temperature`` (float): Amount of randomness injected into the
              response.
            - ``max_completion_tokens`` (int): Maximum number of tokens to
              generate.
            - etc.

    Raises:
        ValueError: If ``model_path`` does not contain a ``/`` separator or the
            provider is not supported.

    .. note::
        For a complete list of supported ``client_args`` and
        ``completion_args`` for each model, refer to the respective API
        documentation.
    """
    try:
        provider, model = model_path.split("/", 1)
    except ValueError:
        # The unpacking error itself carries no useful information for callers.
        raise ValueError(
            "`model_path` must be in the format 'provider/model'"
        ) from None

    if provider != "openai":
        raise ValueError(f"Unsupported provider: {provider}.")

    # Work on a copy so that filling in the API key never leaks into the dict
    # owned by the caller.
    resolved_args = dict(client_args) if client_args is not None else {}

    # Set api_key to value in client_args if present, else from env or None
    resolved_args.setdefault("api_key", os.getenv("OPENAI_API_KEY"))

    _model_singleton._initialize(provider, model, resolved_args, completion_args)


def get_backward_model_client() -> ModelClientSingleton:
    """
    Retrieve the global model client singleton.

    Raises:
        RuntimeError: If no model client is set globally.

    Returns:
        ModelClientSingleton: The global model client.
    """
    # The singleton object always exists; it only counts as "set" once it
    # holds a client.
    if _model_singleton._client is None:
        raise RuntimeError(
            "No global model client set for backward pass. "
            "Call `set_backward_model_client` to define one."
        )
    return _model_singleton