# Source code for afnio.models.openai

import copy
import json
import os
from typing import Dict, Iterable, List, Mapping, Optional, Union

import httpx
from openai import DEFAULT_MAX_RETRIES, NOT_GIVEN, NotGiven
from openai import AsyncOpenAI as AsyncOpenAICli
from openai import OpenAI as OpenAICli
from openai._types import SequenceNotStr
from openai.types.chat import (
    ChatCompletionAudioParam,
    completion_create_params,
)
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from openai.types.chat.chat_completion_prediction_content_param import (
    ChatCompletionPredictionContentParam,
)
from openai.types.chat.chat_completion_stream_options_param import (
    ChatCompletionStreamOptionsParam,
)
from openai.types.chat.chat_completion_tool_choice_option_param import (
    ChatCompletionToolChoiceOptionParam,
)
from openai.types.chat.chat_completion_tool_union_param import (
    ChatCompletionToolUnionParam,
)
from openai.types.chat_model import ChatModel
from openai.types.completion_usage import CompletionUsage
from openai.types.shared.reasoning_effort import ReasoningEffort
from openai.types.shared_params.metadata import Metadata
from typing_extensions import Literal

from .model import ChatCompletionModel, EmbeddingModel, TextCompletionModel

# Provider identifier handed to the base model classes by both clients below.
PROVIDER = "openai"

# Zeroed token counters that every client instance starts from. Callers take a
# deep copy (see the constructors) so that instances never share nested dicts.
INITIAL_USAGE = dict(
    completion_tokens=0,
    prompt_tokens=0,
    total_tokens=0,
    prompt_tokens_details=dict(
        cached_tokens=0,
        audio_tokens=0,
    ),
    completion_tokens_details=dict(
        reasoning_tokens=0,
        audio_tokens=0,
        accepted_prediction_tokens=0,
        rejected_prediction_tokens=0,
    ),
)


class Omit:
    """Sentinel meaning "remove this value" where ``None`` cannot express that.

    Some defaults must be dropped outright rather than overridden. For example,
    a client may send ``Content-Type: application/json`` by default:

    .. code-block:: python

        # as the default `Content-Type` header is `application/json` that will be sent
        client.post("/upload/files", files={"file": b"my raw file content"})

        # you can't explicitly override the header as it has to be dynamically generated
        # to look something like: 'multipart/form-data; boundary=0d8382fcf5f8c3be01ca2e11002d2983'
        client.post(..., headers={"Content-Type": "multipart/form-data"})

        # instead you can remove the default `application/json` header by passing Omit
        client.post(..., headers={"Content-Type": Omit()})

    Instances are always falsy.
    """  # noqa: E501

    def __bool__(self) -> Literal[False]:
        # Falsy so that `if value:` checks treat an Omit like "nothing given".
        return False


# Aliases for the per-request `extra_*` options accepted by `chat` / `achat`.
# Header name -> value; an `Omit()` value requests removal of that header.
Headers = Mapping[str, Union[str, Omit]]
# Query-string parameter name -> value.
Query = Mapping[str, object]
# Extra JSON body content; any object is accepted.
Body = object
class OpenAI(
    TextCompletionModel,
    ChatCompletionModel,
    EmbeddingModel,
    OpenAICli,
):
    """
    OpenAI synchronous client to perform multiple language model operations.

    API calls are delegated to a wrapped ``openai.OpenAI`` instance stored in
    ``self._client``. Token usage reported by chat responses is accumulated
    through :meth:`update_usage`.
    """

    # class-level stubs so Sphinx/autodoc can inspect these attributes safely
    api_key: Optional[str] = None
    organization: Optional[str] = None
    project: Optional[str] = None
    webhook_secret: Optional[str] = None
    websocket_base_url: Optional[Union[str, httpx.URL]] = None

    def __init__(
        self,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[Union[str, httpx.URL]] = None,
        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
        timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Optional[Mapping[str, str]] = None,
        default_query: Optional[Mapping[str, object]] = None,
        http_client: Optional[httpx.Client] = None,
    ):
        """Initialize the client and record its configuration.

        Args:
            api_key: OpenAI API key. When falsy, the ``OPENAI_API_KEY``
                environment variable is used instead.
            organization: Forwarded to the underlying ``openai.OpenAI`` client.
            project: Forwarded to the underlying client.
            base_url: Forwarded to the underlying client.
            websocket_base_url: Forwarded to the underlying client.
            timeout: Forwarded to the underlying client.
            max_retries: Forwarded to the underlying client.
            default_headers: Forwarded to the underlying client.
            default_query: Forwarded to the underlying client.
            http_client: Custom ``httpx.Client`` forwarded to the underlying
                client. It is not recorded in ``config``.

        Every argument except ``http_client`` is collected into ``config``.
        Entries whose value is ``None`` or ``NOT_GIVEN`` are dropped, and each
        remaining entry is checked with ``_validate_config_param``. The base
        classes are then initialized with ``provider``, ``config`` and a fresh
        ``usage`` dict.
        """
        usage = copy.deepcopy(INITIAL_USAGE)  # fresh counters per instance

        # Resolve the key once so the recorded config and the live client agree.
        resolved_api_key = api_key or os.getenv("OPENAI_API_KEY")

        # NOTE(review): the API key is stored in `config`, which is validated
        # for serializability. Confirm it is never persisted or logged verbatim.
        config = {
            "api_key": resolved_api_key,
            "organization": organization,
            "project": project,
            "base_url": base_url,
            "websocket_base_url": websocket_base_url,
            "timeout": timeout,
            "max_retries": max_retries,
            "default_headers": default_headers,
            "default_query": default_query,
        }
        # Remove None and NOT_GIVEN values
        config = {
            k: v for k, v in config.items() if v is not None and v is not NOT_GIVEN
        }
        # Validate serializability
        for k, v in config.items():
            _validate_config_param(k, v)

        self._client = OpenAICli(
            api_key=resolved_api_key,
            organization=organization,
            project=project,
            base_url=base_url,
            websocket_base_url=websocket_base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
        )
        super().__init__(provider=PROVIDER, config=config, usage=usage)

    # TODO: Finalize implementation
    def complete(self, prompt: str) -> str:
        """
        Synchronous method to generate a completion for the given prompt.

        Args:
            prompt (str): The input text for which the model should generate a
                completion.

        Returns:
            str: A string containing the generated completion.

        Raises:
            NotImplementedError: Always; text completion is not implemented yet.
        """
        raise NotImplementedError

    def chat(
        self,
        *,
        messages: Iterable[ChatCompletionMessageParam],
        model: Union[str, ChatModel],
        audio: Union[Optional[ChatCompletionAudioParam], NotGiven] = NOT_GIVEN,
        frequency_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN,
        function_call: Union[
            completion_create_params.FunctionCall, NotGiven
        ] = NOT_GIVEN,
        functions: Union[
            Iterable[completion_create_params.Function], NotGiven
        ] = NOT_GIVEN,
        logit_bias: Union[Optional[Dict[str, int]], NotGiven] = NOT_GIVEN,
        logprobs: Union[Optional[bool], NotGiven] = NOT_GIVEN,
        max_completion_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN,
        max_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN,
        metadata: Union[Optional[Metadata], NotGiven] = NOT_GIVEN,
        modalities: Union[
            Optional[List[Literal["text", "audio"]]], NotGiven
        ] = NOT_GIVEN,
        n: Union[Optional[int], NotGiven] = NOT_GIVEN,
        parallel_tool_calls: Union[bool, NotGiven] = NOT_GIVEN,
        prediction: Union[
            Optional[ChatCompletionPredictionContentParam], NotGiven
        ] = NOT_GIVEN,
        presence_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN,
        prompt_cache_key: Union[str, NotGiven] = NOT_GIVEN,
        reasoning_effort: Union[ReasoningEffort, NotGiven] = NOT_GIVEN,
        response_format: Union[
            completion_create_params.ResponseFormat, NotGiven
        ] = NOT_GIVEN,
        safety_identifier: Union[str, NotGiven] = NOT_GIVEN,
        seed: Union[Optional[int], NotGiven] = NOT_GIVEN,
        service_tier: Union[
            Optional[Literal["auto", "default", "flex", "scale", "priority"]], NotGiven
        ] = NOT_GIVEN,
        stop: Union[
            Union[Optional[str], SequenceNotStr[str], None], NotGiven
        ] = NOT_GIVEN,
        store: Union[Optional[bool], NotGiven] = NOT_GIVEN,
        # TODO: `stream` can be useful during inference, but forbid during training or backpropagation
        stream: Union[Optional[Literal[False]], NotGiven] = NOT_GIVEN,
        stream_options: Union[
            Optional[ChatCompletionStreamOptionsParam], NotGiven
        ] = NOT_GIVEN,
        temperature: Union[Optional[float], NotGiven] = NOT_GIVEN,
        tool_choice: Union[ChatCompletionToolChoiceOptionParam, NotGiven] = NOT_GIVEN,
        tools: Union[Iterable[ChatCompletionToolUnionParam], NotGiven] = NOT_GIVEN,
        top_logprobs: Union[Optional[int], NotGiven] = NOT_GIVEN,
        top_p: Union[Optional[float], NotGiven] = NOT_GIVEN,
        user: Union[str, NotGiven] = NOT_GIVEN,
        verbosity: Union[
            Optional[Literal["low", "medium", "high"]], NotGiven
        ] = NOT_GIVEN,
        web_search_options: Union[
            completion_create_params.WebSearchOptions, NotGiven
        ] = NOT_GIVEN,
        # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
        # The extra values given here take precedence over values defined on the client or passed to this method.
        extra_headers: Optional[Headers] = None,
        extra_query: Optional[Query] = None,
        extra_body: Optional[Body] = None,
        timeout: Optional[Union[float, httpx.Timeout, NotGiven]] = NOT_GIVEN,
    ) -> str:
        """Synchronously creates a model response for the given chat conversation.

        Learn more in the
        [text generation](https://platform.openai.com/docs/guides/text-generation),
        [vision](https://platform.openai.com/docs/guides/vision), and
        [audio](https://platform.openai.com/docs/guides/audio) guides.

        Parameter support can differ depending on the model used to generate the
        response, particularly for newer reasoning models. Parameters that are only
        supported for reasoning models are noted below. For the current state of
        unsupported parameters in reasoning models,
        [refer to the reasoning guide](https://platform.openai.com/docs/guides/reasoning).

        Args:
            messages: A list of messages comprising the conversation so far. Depending
                on the [model](https://platform.openai.com/docs/models) you use,
                different message types (modalities) are supported, like
                [text](https://platform.openai.com/docs/guides/text-generation),
                [images](https://platform.openai.com/docs/guides/vision), and
                [audio](https://platform.openai.com/docs/guides/audio).

            model: Model ID used to generate the response, like `gpt-4o` or `o3`.
                OpenAI offers a wide range of models with different capabilities,
                performance characteristics, and price points. Refer to the
                [model guide](https://platform.openai.com/docs/models) to browse and
                compare available models.

            audio: Parameters for audio output. Required when audio output is
                requested with `modalities: ["audio"]`.
                [Learn more](https://platform.openai.com/docs/guides/audio).

            frequency_penalty: Number between -2.0 and 2.0. Positive values penalize
                new tokens based on their existing frequency in the text so far,
                decreasing the model's likelihood to repeat the same line verbatim.

            function_call: Deprecated in favor of `tool_choice`. Controls which (if
                any) function is called by the model. `none` means the model will
                not call a function and instead generates a message. `auto` means
                the model can pick between generating a message or calling a
                function. Specifying a particular function via
                `{"name": "my_function"}` forces the model to call that function.
                `none` is the default when no functions are present. `auto` is the
                default if functions are present.

            functions: Deprecated in favor of `tools`. A list of functions the model
                may generate JSON inputs for.

            logit_bias: Modify the likelihood of specified tokens appearing in the
                completion. Accepts a JSON object that maps tokens (specified by their
                token ID in the tokenizer) to an associated bias value from -100 to
                100. Mathematically, the bias is added to the logits generated by the
                model prior to sampling. The exact effect will vary per model, but
                values between -1 and 1 should decrease or increase likelihood of
                selection; values like -100 or 100 should result in a ban or
                exclusive selection of the relevant token.

            logprobs: Whether to return log probabilities of the output tokens or
                not. If true, returns the log probabilities of each output token
                returned in the `content` of `message`.

            max_completion_tokens: An upper bound for the number of tokens that can
                be generated for a completion, including visible output tokens and
                [reasoning tokens](https://platform.openai.com/docs/guides/reasoning).

            max_tokens: The maximum number of [tokens](/tokenizer) that can be
                generated in the chat completion. This value can be used to control
                [costs](https://openai.com/api/pricing/) for text generated via API.
                This value is now deprecated in favor of `max_completion_tokens`, and
                is not compatible with
                [o-series models](https://platform.openai.com/docs/guides/reasoning).

            metadata: Set of 16 key-value pairs that can be attached to an object.
                This can be useful for storing additional information about the
                object in a structured format, and querying for objects via API or
                the dashboard. Keys are strings with a maximum length of 64
                characters. Values are strings with a maximum length of 512
                characters.

            modalities: Output types that you would like the model to generate. Most
                models are capable of generating text, which is the default:
                `["text"]`. The `gpt-4o-audio-preview` model can also be used to
                [generate audio](https://platform.openai.com/docs/guides/audio). To
                request that this model generate both text and audio responses, you
                can use: `["text", "audio"]`

            n: How many chat completion choices to generate for each input message.
                Note that you will be charged based on the number of generated tokens
                across all of the choices. Keep `n` as `1` to minimize costs. Values
                greater than 1 are currently rejected by this wrapper.

            parallel_tool_calls: Whether to enable
                [parallel function calling](https://platform.openai.com/docs/guides/function-calling#configuring-parallel-function-calling)
                during tool use.

            prediction: Static predicted output content, such as the content of a
                text file that is being regenerated.

            presence_penalty: Number between -2.0 and 2.0. Positive values penalize
                new tokens based on whether they appear in the text so far,
                increasing the model's likelihood to talk about new topics.

            prompt_cache_key: Used by OpenAI to cache responses for similar requests
                to optimize your cache hit rates. Replaces the `user` field.
                [Learn more](https://platform.openai.com/docs/guides/prompt-caching).

            reasoning_effort: Constrains effort on reasoning for
                [reasoning models](https://platform.openai.com/docs/guides/reasoning).
                Currently supported values are `minimal`, `low`, `medium`, and
                `high`. Reducing reasoning effort can result in faster responses and
                fewer tokens used on reasoning in a response.

            response_format: An object specifying the format that the model must
                output. Setting to `{ "type": "json_schema", "json_schema": {...} }`
                enables Structured Outputs which ensures the model will match your
                supplied JSON schema. Learn more in the
                [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs).
                Setting to `{ "type": "json_object" }` enables the older JSON mode,
                which ensures the message the model generates is valid JSON. Using
                `json_schema` is preferred for models that support it.

            safety_identifier: A stable identifier used to help detect users of your
                application that may be violating OpenAI's usage policies. The IDs
                should be a string that uniquely identifies each user. We recommend
                hashing their username or email address, in order to avoid sending us
                any identifying information.
                [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#safety-identifiers).

            seed: This feature is in Beta. If specified, our system will make a best
                effort to sample deterministically, such that repeated requests with
                the same `seed` and parameters should return the same result.
                Determinism is not guaranteed, and you should refer to the
                `system_fingerprint` response parameter to monitor changes in the
                backend.

            service_tier: Specifies the processing type used for serving the request.

                - If set to 'auto', then the request will be processed with the
                  service tier configured in the Project settings. Unless otherwise
                  configured, the Project will use 'default'.
                - If set to 'default', then the request will be processed with the
                  standard pricing and performance for the selected model.
                - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)'
                  or '[priority](https://openai.com/api-priority-processing/)', then
                  the request will be processed with the corresponding service tier.
                - When not set, the default behavior is 'auto'.

                When the `service_tier` parameter is set, the response body will
                include the `service_tier` value based on the processing mode
                actually used to serve the request. This response value may be
                different from the value set in the parameter.

            stop: Not supported with latest reasoning models `o3` and `o4-mini`.
                Up to 4 sequences where the API will stop generating further tokens.
                The returned text will not contain the stop sequence.

            store: Whether or not to store the output of this chat completion request
                for use in our
                [model distillation](https://platform.openai.com/docs/guides/distillation)
                or [evals](https://platform.openai.com/docs/guides/evals) products.
                Supports text and image inputs. Note: image inputs over 8MB will be
                dropped.

            stream: Streaming is not supported by this method; a truthy value raises
                ``ValueError`` before any request is sent. (In the underlying API, if
                set to true, the model response data will be streamed to the client
                as it is generated using
                [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format).
                See the [Streaming section](https://platform.openai.com/docs/api-reference/chat/streaming)
                and the [streaming responses](https://platform.openai.com/docs/guides/streaming-responses)
                guide.)

            stream_options: Options for streaming response. Only set this when you
                set `stream: true`.

            temperature: What sampling temperature to use, between 0 and 2. Higher
                values like 0.8 will make the output more random, while lower values
                like 0.2 will make it more focused and deterministic. We generally
                recommend altering this or `top_p` but not both.

            tool_choice: Controls which (if any) tool is called by the model. `none`
                means the model will not call any tool and instead generates a
                message. `auto` means the model can pick between generating a message
                or calling one or more tools. `required` means the model must call
                one or more tools. Specifying a particular tool via
                `{"type": "function", "function": {"name": "my_function"}}` forces
                the model to call that tool. `none` is the default when no tools are
                present. `auto` is the default if tools are present.

            tools: A list of tools the model may call. You can provide either
                [custom tools](https://platform.openai.com/docs/guides/function-calling#custom-tools)
                or [function tools](https://platform.openai.com/docs/guides/function-calling).

            top_logprobs: An integer between 0 and 20 specifying the number of most
                likely tokens to return at each token position, each with an
                associated log probability. `logprobs` must be set to `true` if this
                parameter is used.

            top_p: An alternative to sampling with temperature, called nucleus
                sampling, where the model considers the results of the tokens with
                top_p probability mass. So 0.1 means only the tokens comprising the
                top 10% probability mass are considered. We generally recommend
                altering this or `temperature` but not both.

            user: This field is being replaced by `safety_identifier` and
                `prompt_cache_key`. Use `prompt_cache_key` instead to maintain
                caching optimizations. A stable identifier for your end-users. Used
                to boost cache hit rates by better bucketing similar requests and to
                help OpenAI detect and prevent abuse.
                [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#safety-identifiers).

            verbosity: Constrains the verbosity of the model's response. Lower values
                will result in more concise responses, while higher values will
                result in more verbose responses. Currently supported values are
                `low`, `medium`, and `high`.

            web_search_options: This tool searches the web for relevant results to
                use in a response. Learn more about the
                [web search tool](https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat).

            extra_headers: Send extra headers

            extra_query: Add additional query parameters to the request

            extra_body: Add additional JSON properties to the request

            timeout: Override the client-level default timeout for this request, in
                seconds

        Returns:
            str: ``response.choices[0].message.content`` of the first (and only)
            choice. It may be ``None`` when the response message carries no text
            content.

        Raises:
            ValueError: If ``n`` is an int greater than 1, or if ``stream`` is
                truthy. Neither is supported by this wrapper yet.
        """  # noqa: E501
        # TODO: handle `n > 1` that returns multiple completions
        if isinstance(n, int) and n > 1:
            raise ValueError("n > 1 is not supported for chat completions.")
        # A streamed response is an iterator of chunks, not a ChatCompletion, so
        # `response.choices` below would fail, and only after a billed request.
        # NOT_GIVEN, None and False are all falsy, so only an explicit truthy
        # value is rejected.
        if stream:
            raise ValueError("stream=True is not supported for chat completions.")

        response = self._client.chat.completions.create(
            messages=messages,
            model=model,
            audio=audio,
            frequency_penalty=frequency_penalty,
            function_call=function_call,
            functions=functions,
            logit_bias=logit_bias,
            logprobs=logprobs,
            max_completion_tokens=max_completion_tokens,
            max_tokens=max_tokens,
            metadata=metadata,
            modalities=modalities,
            n=n,
            parallel_tool_calls=parallel_tool_calls,
            prediction=prediction,
            presence_penalty=presence_penalty,
            prompt_cache_key=prompt_cache_key,
            reasoning_effort=reasoning_effort,
            response_format=response_format,
            safety_identifier=safety_identifier,
            seed=seed,
            service_tier=service_tier,
            stop=stop,
            store=store,
            stream=stream,
            stream_options=stream_options,
            temperature=temperature,
            tool_choice=tool_choice,
            tools=tools,
            top_logprobs=top_logprobs,
            top_p=top_p,
            user=user,
            verbosity=verbosity,
            web_search_options=web_search_options,
            extra_headers=extra_headers,
            extra_query=extra_query,
            extra_body=extra_body,
            timeout=timeout,
        )

        # Update usage statistics
        if hasattr(response, "usage") and response.usage:
            self.update_usage(response.usage, model)

        return response.choices[0].message.content

    # TODO: Finalize implementation
    def embed(self, input: List[str]) -> List[List[float]]:
        """
        Synchronous method to generate embeddings for the given input texts.

        Args:
            input (List[str]): A list of input strings for which embeddings should
                be generated.

        Returns:
            List[List[float]]: A list of embeddings, where each embedding is
                represented as a list of floats corresponding to the input strings.

        Raises:
            NotImplementedError: Always; embeddings are not implemented yet.
        """
        raise NotImplementedError

    def update_usage(
        self, usage: CompletionUsage, model_name: Optional[str] = None
    ) -> None:
        """Updates the internal usage counters with values from a new API response.

        Missing fields, and fields reported as ``None`` (for example an absent
        ``prompt_tokens_details`` object), count as zero.

        Args:
            usage (CompletionUsage): The usage object returned by the OpenAI API.
                A ``dict`` with the same keys is also accepted.
            model_name (str, optional): The name of the model for which the usage
                is being updated. When given, a cost is computed from that model's
                pricing and added to the running total. If None, cost is copied
                from ``usage["cost"]["amount"]`` if available.
        """
        # Create the counters if they are missing. The previous code called
        # `.update()` on the missing attribute, which always raised AttributeError.
        if not hasattr(self, "_usage"):
            self._usage = copy.deepcopy(INITIAL_USAGE)  # Ensure a fresh copy

        # Ensure we convert CompletionUsage to dict properly
        if isinstance(usage, CompletionUsage):
            usage = usage.model_dump()

        # Core token counts. `or 0` maps both a missing key and an explicit None
        # to 0, whereas `.get(key, 0)` would return None for a None value.
        for key in ("completion_tokens", "prompt_tokens", "total_tokens"):
            self._usage[key] += usage.get(key) or 0

        # The detail objects may be present but None after `model_dump()`, so
        # `.get(key, {})` alone is not enough.
        prompt_details = usage.get("prompt_tokens_details") or {}
        for key in ("cached_tokens", "audio_tokens"):
            self._usage["prompt_tokens_details"][key] += prompt_details.get(key) or 0

        completion_details = usage.get("completion_tokens_details") or {}
        for key in (
            "reasoning_tokens",
            "audio_tokens",
            "accepted_prediction_tokens",
            "rejected_prediction_tokens",
        ):
            self._usage["completion_tokens_details"][key] += (
                completion_details.get(key) or 0
            )

        # Update cost. INITIAL_USAGE has no "cost" entry. It is presumably added
        # by the base class; `setdefault` only guards against a KeyError if not.
        if model_name is not None:
            pricing = _get_pricing_for_model(self.provider, model_name)
            cost = _calculate_cost(usage, pricing)
            cost_entry = self._usage.setdefault("cost", {"amount": 0.0})
            cost_entry["amount"] = cost_entry.get("amount", 0.0) + cost
        else:
            # If cost is present in usage, copy it directly (it replaces, not adds).
            usage_cost = usage.get("cost")
            if isinstance(usage_cost, Mapping) and "amount" in usage_cost:
                self._usage.setdefault("cost", {})["amount"] = usage_cost["amount"]
class AsyncOpenAI(
    TextCompletionModel,
    ChatCompletionModel,
    EmbeddingModel,
    AsyncOpenAICli,
):
    """
    OpenAI asynchronous client to perform multiple language model operations.

    API calls are delegated to a wrapped ``openai.AsyncOpenAI`` instance
    stored in ``self._aclient``.
    """

    # class-level stubs so Sphinx/autodoc can inspect these attributes safely
    api_key: Optional[str] = None
    organization: Optional[str] = None
    project: Optional[str] = None
    webhook_secret: Optional[str] = None
    websocket_base_url: Optional[Union[str, httpx.URL]] = None

    def __init__(
        self,
        api_key: Optional[str] = None,
        organization: Optional[str] = None,
        project: Optional[str] = None,
        base_url: Optional[Union[str, httpx.URL]] = None,
        websocket_base_url: Optional[Union[str, httpx.URL]] = None,
        timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN,
        max_retries: int = DEFAULT_MAX_RETRIES,
        default_headers: Optional[Mapping[str, str]] = None,
        default_query: Optional[Mapping[str, object]] = None,
        http_client: Optional[httpx.AsyncClient] = None,
    ):
        """Set up the async client and record its configuration.

        ``api_key`` falls back to the ``OPENAI_API_KEY`` environment variable
        when falsy. All arguments are forwarded to ``openai.AsyncOpenAI``. All
        of them except ``http_client`` are also recorded in ``config``, minus
        any that are ``None`` or ``NOT_GIVEN``. Each recorded entry is checked
        with ``_validate_config_param`` before the base classes are initialized.
        """
        counters = copy.deepcopy(INITIAL_USAGE)  # per-instance counters
        effective_key = api_key or os.getenv("OPENAI_API_KEY")

        # Ordered (name, value) pairs; the order fixes the key order of `config`.
        candidate_settings = (
            ("api_key", effective_key),
            ("organization", organization),
            ("project", project),
            ("base_url", base_url),
            ("websocket_base_url", websocket_base_url),
            ("timeout", timeout),
            ("max_retries", max_retries),
            ("default_headers", default_headers),
            ("default_query", default_query),
        )
        # Keep only settings that were actually provided.
        config = {
            name: value
            for name, value in candidate_settings
            if value is not None and value is not NOT_GIVEN
        }
        for name, value in config.items():
            _validate_config_param(name, value)

        self._aclient = AsyncOpenAICli(
            api_key=effective_key,
            organization=organization,
            project=project,
            base_url=base_url,
            websocket_base_url=websocket_base_url,
            timeout=timeout,
            max_retries=max_retries,
            default_headers=default_headers,
            default_query=default_query,
            http_client=http_client,
        )
        super().__init__(provider=PROVIDER, config=config, usage=counters)

    # TODO: Finalize implementation
    async def acomplete(self, prompt: str) -> str:
        """
        Asynchronous method to generate a completion for the given prompt.

        Args:
            prompt (str): Text the model should continue.

        Returns:
            str: The generated completion.

        Raises:
            NotImplementedError: Always; not implemented yet.
        """
        raise NotImplementedError
[docs] async def achat( self, *, messages: Iterable[ChatCompletionMessageParam], model: Union[str, ChatModel], audio: Union[Optional[ChatCompletionAudioParam], NotGiven] = NOT_GIVEN, frequency_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN, function_call: Union[ completion_create_params.FunctionCall, NotGiven ] = NOT_GIVEN, functions: Union[ Iterable[completion_create_params.Function], NotGiven ] = NOT_GIVEN, logit_bias: Union[Optional[Dict[str, int]], NotGiven] = NOT_GIVEN, logprobs: Union[Optional[bool], NotGiven] = NOT_GIVEN, max_completion_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN, max_tokens: Union[Optional[int], NotGiven] = NOT_GIVEN, metadata: Union[Optional[Metadata], NotGiven] = NOT_GIVEN, modalities: Union[ Optional[List[Literal["text", "audio"]]], NotGiven ] = NOT_GIVEN, n: Union[Optional[int], NotGiven] = NOT_GIVEN, parallel_tool_calls: Union[bool, NotGiven] = NOT_GIVEN, prediction: Union[ Optional[ChatCompletionPredictionContentParam], NotGiven ] = NOT_GIVEN, presence_penalty: Union[Optional[float], NotGiven] = NOT_GIVEN, prompt_cache_key: Union[str, NotGiven] = NOT_GIVEN, reasoning_effort: Union[ReasoningEffort, NotGiven] = NOT_GIVEN, response_format: Union[ completion_create_params.ResponseFormat, NotGiven ] = NOT_GIVEN, safety_identifier: Union[str, NotGiven] = NOT_GIVEN, seed: Union[Optional[int], NotGiven] = NOT_GIVEN, service_tier: Union[ Optional[Literal["auto", "default", "flex", "scale", "priority"]], NotGiven ] = NOT_GIVEN, stop: Union[ Union[Optional[str], SequenceNotStr[str], None], NotGiven ] = NOT_GIVEN, store: Union[Optional[bool], NotGiven] = NOT_GIVEN, # TODO: `stream` can be useful during inference, but forbid during training or backpropagation stream: Union[Optional[Literal[False]], NotGiven] = NOT_GIVEN, stream_options: Union[ Optional[ChatCompletionStreamOptionsParam], NotGiven ] = NOT_GIVEN, temperature: Union[Optional[float], NotGiven] = NOT_GIVEN, tool_choice: Union[ChatCompletionToolChoiceOptionParam, NotGiven] 
= NOT_GIVEN, tools: Union[Iterable[ChatCompletionToolUnionParam], NotGiven] = NOT_GIVEN, top_logprobs: Union[Optional[int], NotGiven] = NOT_GIVEN, top_p: Union[Optional[float], NotGiven] = NOT_GIVEN, user: Union[str, NotGiven] = NOT_GIVEN, verbosity: Union[ Optional[Literal["low", "medium", "high"]], NotGiven ] = NOT_GIVEN, web_search_options: Union[ completion_create_params.WebSearchOptions, NotGiven ] = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. extra_headers: Optional[Headers] = None, extra_query: Optional[Query] = None, extra_body: Optional[Body] = None, timeout: Optional[Union[float, httpx.Timeout, NotGiven]] = NOT_GIVEN, ) -> str: """Asynchronously creates a model response for the given chat conversation. Learn more in the [text generation](https://platform.openai.com/docs/guides/text-generation), [vision](https://platform.openai.com/docs/guides/vision), and [audio](https://platform.openai.com/docs/guides/audio) guides. Parameter support can differ depending on the model used to generate the response, particularly for newer reasoning models. Parameters that are only supported for reasoning models are noted below. For the current state of unsupported parameters in reasoning models, [refer to the reasoning guide](https://platform.openai.com/docs/guides/reasoning). Args: messages: A list of messages comprising the conversation so far. Depending on the [model](https://platform.openai.com/docs/models) you use, different message types (modalities) are supported, like [text](https://platform.openai.com/docs/guides/text-generation), [images](https://platform.openai.com/docs/guides/vision), and [audio](https://platform.openai.com/docs/guides/audio). model: Model ID used to generate the response, like `gpt-4o` or `o3`. 
OpenAI offers a wide range of models with different capabilities, performance characteristics, and price points. Refer to the [model guide](https://platform.openai.com/docs/models) to browse and compare available models. audio: Parameters for audio output. Required when audio output is requested with `modalities: ["audio"]`. [Learn more](https://platform.openai.com/docs/guides/audio). frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. function_call: Deprecated in favor of `tool_choice`. Controls which (if any) function is called by the model. `none` means the model will not call a function and instead generates a message. `auto` means the model can pick between generating a message or calling a function. Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. `none` is the default when no functions are present. `auto` is the default if functions are present. functions: Deprecated in favor of `tools`. A list of functions the model may generate JSON inputs for. logit_bias: Modify the likelihood of specified tokens appearing in the completion. Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. logprobs: Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`. 
max_completion_tokens: An upper bound for the number of tokens that can be generated for a completion, including visible output tokens and [reasoning tokens](https://platform.openai.com/docs/guides/reasoning). max_tokens: The maximum number of [tokens](/tokenizer) that can be generated in the chat completion. This value can be used to control [costs](https://openai.com/api/pricing/) for text generated via API. This value is now deprecated in favor of `max_completion_tokens`, and is not compatible with [o-series models](https://platform.openai.com/docs/guides/reasoning). metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format, and querying for objects via API or the dashboard. Keys are strings with a maximum length of 64 characters. Values are strings with a maximum length of 512 characters. modalities: Output types that you would like the model to generate. Most models are capable of generating text, which is the default: `["text"]` The `gpt-4o-audio-preview` model can also be used to [generate audio](https://platform.openai.com/docs/guides/audio). To request that this model generate both text and audio responses, you can use: `["text", "audio"]` n: How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs. parallel_tool_calls: Whether to enable [parallel function calling](https://platform.openai.com/docs/guides/function-calling#configuring-parallel-function-calling) during tool use. prediction: Static predicted output content, such as the content of a text file that is being regenerated. presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. 
prompt_cache_key: Used by OpenAI to cache responses for similar requests to optimize your cache hit rates. Replaces the `user` field. [Learn more](https://platform.openai.com/docs/guides/prompt-caching). reasoning_effort: Constrains effort on reasoning for [reasoning models](https://platform.openai.com/docs/guides/reasoning). Currently supported values are `minimal`, `low`, `medium`, and `high`. Reducing reasoning effort can result in faster responses and fewer tokens used on reasoning in a response. response_format: An object specifying the format that the model must output. Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which ensures the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](https://platform.openai.com/docs/guides/structured-outputs). Setting to `{ "type": "json_object" }` enables the older JSON mode, which ensures the message the model generates is valid JSON. Using `json_schema` is preferred for models that support it. safety_identifier: A stable identifier used to help detect users of your application that may be violating OpenAI's usage policies. The IDs should be a string that uniquely identifies each user. We recommend hashing their username or email address, in order to avoid sending us any identifying information. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#safety-identifiers). seed: This feature is in Beta. If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result. Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend. service_tier: Specifies the processing type used for serving the request. - If set to 'auto', then the request will be processed with the service tier configured in the Project settings. 
Unless otherwise configured, the Project will use 'default'. - If set to 'default', then the request will be processed with the standard pricing and performance for the selected model. - If set to '[flex](https://platform.openai.com/docs/guides/flex-processing)' or '[priority](https://openai.com/api-priority-processing/)', then the request will be processed with the corresponding service tier. - When not set, the default behavior is 'auto'. When the `service_tier` parameter is set, the response body will include the `service_tier` value based on the processing mode actually used to serve the request. This response value may be different from the value set in the parameter. stop: Not supported with latest reasoning models `o3` and `o4-mini`. Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. store: Whether or not to store the output of this chat completion request for use in our [model distillation](https://platform.openai.com/docs/guides/distillation) or [evals](https://platform.openai.com/docs/guides/evals) products. Supports text and image inputs. Note: image inputs over 8MB will be dropped. stream: If set to true, the model response data will be streamed to the client as it is generated using [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format). See the [Streaming section below](https://platform.openai.com/docs/api-reference/chat/streaming) for more information, along with the [streaming responses](https://platform.openai.com/docs/guides/streaming-responses) guide for more information on how to handle the streaming events. stream_options: Options for streaming response. Only set this when you set `stream: true`. temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. 
We generally recommend altering this or `top_p` but not both. tool_choice: Controls which (if any) tool is called by the model. `none` means the model will not call any tool and instead generates a message. `auto` means the model can pick between generating a message or calling one or more tools. `required` means the model must call one or more tools. Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. `none` is the default when no tools are present. `auto` is the default if tools are present. tools: A list of tools the model may call. You can provide either [custom tools](https://platform.openai.com/docs/guides/function-calling#custom-tools) or [function tools](https://platform.openai.com/docs/guides/function-calling). top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used. top_p: An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. We generally recommend altering this or `temperature` but not both. user: This field is being replaced by `safety_identifier` and `prompt_cache_key`. Use `prompt_cache_key` instead to maintain caching optimizations. A stable identifier for your end-users. Used to boost cache hit rates by better bucketing similar requests and to help OpenAI detect and prevent abuse. [Learn more](https://platform.openai.com/docs/guides/safety-best-practices#safety-identifiers). verbosity: Constrains the verbosity of the model's response. Lower values will result in more concise responses, while higher values will result in more verbose responses. Currently supported values are `low`, `medium`, and `high`. 
web_search_options: This tool searches the web for relevant results to use in a response. Learn more about the [web search tool](https://platform.openai.com/docs/guides/tools-web-search?api-mode=chat). extra_headers: Send extra headers extra_query: Add additional query parameters to the request extra_body: Add additional JSON properties to the request timeout: Override the client-level default timeout for this request, in seconds """ # noqa: E501 # TODO: handle `n > 1`` that returns multiple completions if isinstance(n, int) and n > 1: raise ValueError("n > 1 is not supported for async chat completions.") response = await self._aclient.chat.completions.create( messages=messages, model=model, audio=audio, frequency_penalty=frequency_penalty, function_call=function_call, functions=functions, logit_bias=logit_bias, logprobs=logprobs, max_completion_tokens=max_completion_tokens, max_tokens=max_tokens, metadata=metadata, modalities=modalities, n=n, parallel_tool_calls=parallel_tool_calls, prediction=prediction, presence_penalty=presence_penalty, prompt_cache_key=prompt_cache_key, reasoning_effort=reasoning_effort, response_format=response_format, safety_identifier=safety_identifier, seed=seed, service_tier=service_tier, stop=stop, store=store, stream=stream, stream_options=stream_options, temperature=temperature, tool_choice=tool_choice, tools=tools, top_logprobs=top_logprobs, top_p=top_p, user=user, verbosity=verbosity, web_search_options=web_search_options, extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout, ) # Update usage statistics if hasattr(response, "usage") and response.usage: self.update_usage(response.usage, model) return response.choices[0].message.content
# TODO: Finalize implementation (embedding generation is not wired up yet).
async def aembed(self, input: List[str]) -> List[List[float]]:
    """Asynchronously compute embeddings for a batch of texts.

    This coroutine is currently a placeholder: it performs no API call and
    raises ``NotImplementedError`` as soon as it is awaited.

    Args:
        input (List[str]): The texts to embed.

    Returns:
        List[List[float]]: Intended to be one embedding vector (a list of
        floats) per input string. Never returned in the current state.

    Raises:
        NotImplementedError: Always, until the implementation is finalized.
    """
    # Signal the missing implementation explicitly rather than returning
    # an empty or partial result.
    raise NotImplementedError
def update_usage(self, usage: CompletionUsage, model_name: Optional[str] = None) -> None:
    """Accumulate token counts (and cost) from one API response into ``self._usage``.

    Top-level counters and the per-section detail counters listed in
    ``INITIAL_USAGE`` are added to the running totals. Missing sections or
    counters, including ones present but set to ``None``, count as zero.

    Args:
        usage (CompletionUsage): The usage object returned by the OpenAI API,
            or a plain dict with the same keys.
        model_name (str, optional): Name of the model that produced ``usage``.
            When given, the cost of this call is computed from the bundled
            pricing table and added to ``self._usage["cost"]["amount"]``.
            When None, ``usage["cost"]["amount"]`` is copied (overwriting,
            not adding) into the running totals if it is present.
    """
    # Lazily create the accumulator when the instance has none yet.
    if not hasattr(self, "_usage"):
        self._usage = copy.deepcopy(INITIAL_USAGE)  # fresh copy; never alias the constant

    # Work on a plain dict so API objects and dicts are handled uniformly.
    if isinstance(usage, CompletionUsage):
        usage = usage.model_dump()

    # Core token counters. `or 0` also covers keys present with a None value.
    for key in ("completion_tokens", "prompt_tokens", "total_tokens"):
        self._usage[key] = self._usage.get(key, 0) + (usage.get(key) or 0)

    # Detail sections: the counters to track are taken from INITIAL_USAGE.
    # A section may be None in a dumped CompletionUsage, hence `or {}`.
    for section in ("prompt_tokens_details", "completion_tokens_details"):
        incoming = usage.get(section) or {}
        totals = self._usage.setdefault(section, {})
        for field in INITIAL_USAGE[section]:
            totals[field] = totals.get(field, 0) + (incoming.get(field) or 0)

    # Cost: priced from the model name when known, otherwise copied from usage.
    if model_name is not None:
        pricing = _get_pricing_for_model(self.provider, model_name)
        # INITIAL_USAGE has no "cost" entry, so create it on first use.
        cost_totals = self._usage.setdefault("cost", {})
        cost_totals["amount"] = cost_totals.get("amount", 0.0) + _calculate_cost(
            usage, pricing
        )
    else:
        incoming_cost = usage.get("cost")
        if isinstance(incoming_cost, Mapping) and "amount" in incoming_cost:
            self._usage.setdefault("cost", {})["amount"] = incoming_cost["amount"]
def _validate_config_param(name, value): # Accept basic JSON types if isinstance(value, (str, int, float, bool, type(None))): return # Recursively check mappings (dict, Mapping) if isinstance(value, Mapping): for k, v in value.items(): if not isinstance(k, str): raise TypeError( f"Config parameter '{name}' " f"has a non-string key '{k}' of type {type(k).__name__}." ) _validate_config_param(f"{name}.{k}", v) return # Recursively check sequences (list, tuple, set) but not str/bytes if isinstance(value, (list, tuple, set)): for idx, item in enumerate(value): _validate_config_param(f"{name}[{idx}]", item) return # Explicitly reject bytes and bytearray if isinstance(value, (bytes, bytearray)): raise TypeError( f"Config parameter '{name}' " f"is of type {type(value).__name__}, which is not JSON serializable." ) # Fallback: try json serialization try: json.dumps(value) except Exception: raise TypeError( f"Config parameter '{name}' " f"with value '{value}' of type {type(value).__name__} is not serializable." 
) def _get_pricing_for_model(provider: str, model_name: str) -> dict: # Load pricing data (cache this in production) prices_path = os.path.join(os.path.dirname(__file__), "model_prices.json") with open(prices_path, "r") as f: prices = json.load(f) provider_data = prices.get(provider, {}) models_map = provider_data.get("models", {}) pricing_map = provider_data.get("pricing", {}) pricing_key = models_map.get(model_name, model_name) return pricing_map.get(pricing_key, {}) def _calculate_cost(usage: dict, pricing: dict) -> float: input_tokens = usage.get("prompt_tokens", 0) cached_tokens = usage.get("prompt_tokens_details", {}).get("cached_tokens", 0) uncached_tokens = input_tokens - cached_tokens output_tokens = usage.get("completion_tokens", 0) cost = 0.0 if "input" in pricing and pricing["input"] is not None: cost += (uncached_tokens * pricing["input"]) / 1_000_000 if "cached" in pricing and pricing["cached"] is not None: cost += (cached_tokens * pricing["cached"]) / 1_000_000 if "output" in pricing and pricing["output"] is not None: cost += (output_tokens * pricing["output"]) / 1_000_000 return cost