Source code for afnio.models.model

import copy
import logging
from abc import ABC
from typing import Dict, List, Optional, Union

from afnio.logging_config import configure_logging
from afnio.tellurio._client_manager import get_default_clients
from afnio.tellurio._eventloop import run_in_background_loop
from afnio.tellurio._model_registry import register_model
from afnio.tellurio.consent import check_consent

INITIAL_COST = {"cost": {"amount": 0.0, "currency": "USD"}}


# Configure logging
configure_logging()
logger = logging.getLogger(__name__)


class BaseModel(ABC):
    """
    An abstraction for a model.
    """

    def __init__(
        self,
        provider: str = None,
        config: Optional[dict] = None,
        usage: Optional[dict] = None,
    ):
        self.provider = provider
        self._config = config or {}
        self._usage = usage or {}
        self._usage.update(copy.deepcopy(INITIAL_COST))
        self.model_id = None

        # Request user consent before sending sensitive info to the server
        check_consent()

        try:
            # Get the singleton websocket client
            _, ws_client = get_default_clients()
            payload = {
                "class_type": self.__class__.__name__,
                "provider": self.provider,
                "config": self.get_config(),
                "usage": self.get_usage(),
            }
            response = run_in_background_loop(ws_client.call("create_model", payload))
            if "error" in response:
                raise RuntimeError(
                    response["error"]["data"].get("exception", response["error"])
                )

            logger.debug(f"LM model created and shared with the server: {self!r}")

            model_id = response["result"].get("model_id")
            if not model_id:
                raise RuntimeError(
                    f"Server did not return a model_id "
                    f"for payload: {payload!r}, response: {response!r}"
                )
            self.model_id = model_id

            register_model(self)
        except Exception as e:
            logger.error(f"Failed to share LM model with the server: {e}")
            raise

    def get_provider(self) -> Optional[str]:
        """Returns the model provider name."""
        return self.provider

    def get_config(self) -> Dict[str, Union[str, float, int]]:
        """
        Returns the model configuration.

        This includes the model name, temperature, max tokens, and other parameters
        that are used to configure the model's behavior.

        Returns:
            dict: A dictionary containing the model's configuration parameters.
        """
        return self._config

    def update_usage(self, usage: Dict[str, int], model_name: str = None) -> None:
        """
        Updates the internal token usage statistics and cost.

        Each model provider (e.g., OpenAI, Anthropic) may have a different usage
        format. This method should be implemented by subclasses to ensure correct
        parsing and aggregation of token usage.

        Behavior:
            - If `model_name` is provided, the method dynamically calculates and
              updates the cost based on the usage metrics and the pricing for the
              specified model.
            - If `model_name` is None, the method copies the cost value directly
              from the `usage` dictionary (if present), which is typically used
              when restoring state from a checkpoint.

        Args:
            usage (Dict[str, int]): A dictionary containing token usage metrics,
                such as `prompt_tokens`, `completion_tokens`, and `total_tokens`.
            model_name (str, optional): The name of the model for which the usage
                is being updated. If None, cost is copied from usage if available.

        Raises:
            NotImplementedError: If called on the base class without an
                implementation.
        """
        raise NotImplementedError

    def get_usage(self) -> Dict[str, int]:
        """
        Retrieves the current token usage statistics and cost (in USD).

        Returns:
            Dict[str, int]: A dictionary containing cumulative token usage
                statistics since the model instance was initialized.

        Example:
            >>> model.get_usage()
            {
                'prompt_tokens': 1500,
                'completion_tokens': 1200,
                'total_tokens': 2700,
                'cost': {'amount': 12.00, 'currency': 'USD'}
            }
        """
        return self._usage.copy()

    def clear_usage(self) -> None:
        """
        Clears the token usage statistics.

        This resets all numerical values in the usage dictionary to zero (including
        nested values), while preserving the dictionary structure.
        """
        try:
            # Get the singleton websocket client
            _, ws_client = get_default_clients()
            payload = {
                "model_id": self.model_id,
            }
            response = run_in_background_loop(
                ws_client.call("clear_model_usage", payload)
            )
            if "error" in response:
                raise RuntimeError(
                    response["error"]["data"].get("exception", response["error"])
                )

            model_id = response["result"].get("model_id")
            if not model_id:
                raise RuntimeError(
                    f"Server did not return a model_id "
                    f"for payload: {payload!r}, response: {response!r}"
                )

            logger.debug(f"LM model usage cleared on the server: {self!r}")
        except Exception as e:
            logger.error(f"Failed to clear LM model usage on the server: {e}")
            raise

    def __deepcopy__(self, memo):
        """
        Custom deepcopy to save only the class type and metadata like usage.
        """
        if id(self) in memo:
            return memo[id(self)]

        # Save only the class type and any necessary metadata (e.g., usage details)
        cls_copy = {
            "class_type": self.__class__.__name__,
            "provider": self.provider,
            "usage": self.get_usage(),
        }

        # Store the copied object in memo before returning it
        memo[id(self)] = cls_copy

        return cls_copy
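
# Subclassing sketch (illustrative only; `MyProviderModel` and the exact usage keys
# below are assumptions, not part of this module). A provider-specific model is
# expected to subclass `BaseModel` and implement `update_usage` so that per-call
# token counts accumulate into `self._usage`, as described in the docstring above:
#
#     class MyProviderModel(BaseModel):
#         def update_usage(self, usage, model_name=None):
#             # Accumulate token counts reported by the provider
#             for key in ("prompt_tokens", "completion_tokens", "total_tokens"):
#                 self._usage[key] = self._usage.get(key, 0) + usage.get(key, 0)
#             if model_name is None and "cost" in usage:
#                 # Restoring from a checkpoint: copy the cost verbatim
#                 self._usage["cost"] = copy.deepcopy(usage["cost"])
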
# TODO: handle caching
class TextCompletionModel(BaseModel):
    """
    An abstraction for a language model that accepts a prompt composed of a single
    text input and generates a textual completion.
    """

    def __init__(self, provider: str = None, **kwargs):
        super().__init__(provider=provider, **kwargs)

    async def acomplete(self, prompt: str, **kwargs) -> str:
        """
        Asynchronous method to generate a completion for the given prompt.

        Args:
            prompt (str): The input text for which the model should generate
                a completion.
            **kwargs: Additional parameters to configure the model's behavior
                during completion. This may include options such as:

                - model (str): The model to use (e.g., "gpt-4o").
                - temperature (float): Amount of randomness injected into the
                  response.
                - max_completion_tokens (int): Maximum number of tokens to generate.
                - etc.

                For a complete list of supported parameters for each model, refer
                to the respective API documentation.

        Returns:
            str: A string containing the generated completion.
        """
        raise NotImplementedError

    def complete(self, prompt: str, **kwargs) -> str:
        """
        Synchronous method to generate a completion for the given prompt.

        Args:
            prompt (str): The input text for which the model should generate
                a completion.
            **kwargs: Additional parameters to configure the model's behavior
                during completion. This may include options such as:

                - model (str): The model to use (e.g., "gpt-4o").
                - temperature (float): Amount of randomness injected into the
                  response.
                - max_completion_tokens (int): Maximum number of tokens to generate.
                - etc.

                For a complete list of supported parameters for each model, refer
                to the respective API documentation.

        Returns:
            str: A string containing the generated completion.
        """
        raise NotImplementedError
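
# Usage sketch (illustrative only; `MyTextCompletionModel` is a hypothetical
# subclass that implements `complete` against a concrete provider API):
#
#     model = MyTextCompletionModel(provider="openai")
#     text = model.complete(
#         prompt="Write a haiku about databases.",
#         model="gpt-4o",
#         max_completion_tokens=64,
#     )
#     print(text)
#     print(model.get_usage())
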
# TODO: handle caching
class ChatCompletionModel(BaseModel):
    """
    An abstraction for a language model that accepts a prompt composed of an array
    of messages containing instructions for the model. Each message can have a
    different role, influencing how the model interprets the input.
    """

    def __init__(self, provider: str = None, **kwargs):
        super().__init__(provider=provider, **kwargs)

    # TODO: Add link to `API documentation` for kwargs of each supported model
    async def achat(self, messages: List[Dict[str, str]], **kwargs) -> str:
        """
        Asynchronous method to handle chat-based interactions with the model.

        Args:
            messages (List[Dict[str, str]]): A list of messages, where each message
                is represented as a dictionary with "role" (e.g., "user", "system")
                and "content" (the text of the message).
            **kwargs: Additional parameters to configure the model's behavior
                during chat completion. This may include options such as:

                - model (str): The model to use (e.g., "gpt-4o").
                - temperature (float): Amount of randomness injected into the
                  response.
                - max_completion_tokens (int): Maximum number of tokens to generate.
                - etc.

                For a complete list of supported parameters for each model, refer
                to the respective API documentation.

        Returns:
            str: A string containing the model's response to the chat messages.
        """
        raise NotImplementedError

    def chat(self, messages: List[Dict[str, str]], **kwargs) -> str:
        """
        Synchronous method to handle chat-based interactions with the model.

        Args:
            messages (List[Dict[str, str]]): A list of messages, where each message
                is represented as a dictionary with "role" (e.g., "user", "system")
                and "content" (the text of the message).
            **kwargs: Additional parameters to configure the model's behavior
                during chat completion. This may include options such as:

                - model (str): The model to use (e.g., "gpt-4o").
                - temperature (float): Amount of randomness injected into the
                  response.
                - max_completion_tokens (int): Maximum number of tokens to generate.
                - etc.

                For a complete list of supported parameters for each model, refer
                to the respective API documentation.

        Returns:
            str: A string containing the model's response to the chat messages.
        """
        raise NotImplementedError
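
# Usage sketch (illustrative only; `MyChatModel` is a hypothetical subclass that
# implements `chat` against a concrete provider API):
#
#     model = MyChatModel(provider="openai")
#     reply = model.chat(
#         messages=[
#             {"role": "system", "content": "You are a helpful assistant."},
#             {"role": "user", "content": "Summarize this paragraph in one line."},
#         ],
#         model="gpt-4o",
#         temperature=0.2,
#     )
#     print(reply)
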
# TODO: handle caching
class EmbeddingModel(BaseModel):
    """
    An abstraction for a model that generates embeddings for input texts.
    """

    def __init__(self, provider: str = None, **kwargs):
        super().__init__(provider=provider, **kwargs)

    async def aembed(self, input: List[str], **kwargs) -> List[List[float]]:
        """
        Asynchronous method to generate embeddings for the given input texts.

        Args:
            input (List[str]): A list of input strings for which embeddings should
                be generated.
            **kwargs: Additional parameters to configure the model's behavior
                during embedding generation. This may include options such as:

                - model (str): The embedding model to use.
                - etc.

                For a complete list of supported parameters for each model, refer
                to the respective API documentation.

        Returns:
            List[List[float]]: A list of embeddings, where each embedding is
                represented as a list of floats corresponding to the input strings.
        """
        raise NotImplementedError

    def embed(self, input: List[str], **kwargs) -> List[List[float]]:
        """
        Synchronous method to generate embeddings for the given input texts.

        Args:
            input (List[str]): A list of input strings for which embeddings should
                be generated.
            **kwargs: Additional parameters to configure the model's behavior
                during embedding generation. This may include options such as:

                - model (str): The embedding model to use.
                - etc.

                For a complete list of supported parameters for each model, refer
                to the respective API documentation.

        Returns:
            List[List[float]]: A list of embeddings, where each embedding is
                represented as a list of floats corresponding to the input strings.
        """
        raise NotImplementedError
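
# Usage sketch (illustrative only; `MyEmbeddingModel` is a hypothetical subclass
# that implements `embed` against a concrete provider API):
#
#     model = MyEmbeddingModel(provider="openai")
#     vectors = model.embed(input=["first document", "second document"])
#     assert len(vectors) == 2
#     assert all(isinstance(value, float) for value in vectors[0])
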