Source code for afnio.cognitive.modules.chat_completion
from typing import Any, Dict, Optional, Union
from afnio._utils import MultiTurnMessages
from afnio._variable import Variable
from afnio.autodiff.lm_ops import ChatCompletion as ChatCompletionOp
from afnio.models import ChatCompletionModel
from .module import Module
[docs]
class ChatCompletion(Module):
"""
Generates a chat-based completion using a language model.
This module leverages the `ChatCompletion` operation from
`afnio.autodiff.lm_ops` to perform model inference. The `forward` method
accepts a list of `messages` representing the conversation history, with optional
dynamic `inputs` for filling placeholders within the messages. The
`forward_model_client` is responsible for interfacing with the language model
(e.g., GPT), while `completion_args` allows customization of generation parameters
such as temperature, maximum tokens, and seed.
Example:
>>> import afnio as hf
>>> from afnio import cognitive as cog
>>> from afnio.models.openai import OpenAI
>>> from afnio import set_backward_model_client
>>> fwd_model_client = OpenAI()
>>> fwd_model_args = {"model": "gpt-4o", "temperature": 0.7}
>>> set_backward_model_client("openai/gpt-4o")
>>> class Assistant(cog.Module):
... def __init__(self):
... super().__init__()
... self.chat = cog.ChatCompletion()
... def forward(self, fwd_model, messages, inputs, **completion_args):
... return self.chat(fwd_model, messages, inputs, **completion_args)
>>> system = Variable(
... "You are a helpful assistant.",
... role="system instruction",
... requires_grad=True
... )
>>> user = Variable("Translate 'Hello' to {language}.", role="user query")
>>> language = hf.Variable("Italian", role="language")
>>> messages = [
... {"role": "system", "content": [system]},
... {"role": "user", "content": [user]},
... ]
>>> model = Assistant()
>>> response = model(
... fwd_model_client,
... messages,
... inputs={"language": language},
... **fwd_model_args
... )
>>> print(response.data)
'Ciao'
>>> feedback = Variable("Use only capital letters.", role="feedback")
>>> response.backward(feedback)
>>> system.grad[0].data
'The system instruction should enforce the use of capital letters only.'
See Also:
:class:`afnio.autodiff.lm_ops.ChatCompletion` for the underlying operation.
"""
forward_model_client: Optional[ChatCompletionModel]
messages: MultiTurnMessages
completion_args: Dict[str, Any]
def __init__(self):
super().__init__()
self.register_model("forward_model_client", None)
self.register_chat("messages", None)
self.register_completion_config("completion_args", None)
[docs]
def forward(
self,
forward_model_client: Optional[ChatCompletionModel],
messages: MultiTurnMessages,
inputs: Optional[Dict[str, Union[str, Variable]]] = None,
**completion_args,
) -> Variable:
self.forward_model_client = forward_model_client
self.messages = messages
self.completion_args = completion_args
return ChatCompletionOp.apply(
self.forward_model_client,
self.messages,
inputs,
**self.completion_args,
)