afnio.cognitive.modules.chat_completion
Classes
ChatCompletion – Generates a chat-based completion using a language model.
- class afnio.cognitive.modules.chat_completion.ChatCompletion
Bases: Module
Generates a chat-based completion using a language model.
This module leverages the ChatCompletion operation from afnio.autodiff.lm_ops to perform model inference. The forward method accepts a list of messages representing the conversation history, with optional dynamic inputs for filling placeholders within the messages. The forward_model_client is responsible for interfacing with the language model (e.g., GPT), while completion_args allows customization of generation parameters such as temperature, maximum tokens, and seed.
Example
>>> import afnio as hf
>>> from afnio import cognitive as cog
>>> from afnio.models.openai import OpenAI
>>> from afnio import set_backward_model_client
>>> fwd_model_client = OpenAI()
>>> fwd_model_args = {"model": "gpt-4o", "temperature": 0.7}
>>> set_backward_model_client("openai/gpt-4o")
>>> class Assistant(cog.Module):
...     def __init__(self):
...         super().__init__()
...         self.chat = cog.ChatCompletion()
...     def forward(self, fwd_model, messages, inputs, **completion_args):
...         return self.chat(fwd_model, messages, inputs, **completion_args)
>>> system = hf.Variable(
...     "You are a helpful assistant.",
...     role="system instruction",
...     requires_grad=True
... )
>>> user = hf.Variable("Translate 'Hello' to {language}.", role="user query")
>>> language = hf.Variable("Italian", role="language")
>>> messages = [
...     {"role": "system", "content": [system]},
...     {"role": "user", "content": [user]},
... ]
>>> model = Assistant()
>>> response = model(
...     fwd_model_client,
...     messages,
...     inputs={"language": language},
...     **fwd_model_args
... )
>>> print(response.data)
Ciao
>>> feedback = hf.Variable("Use only capital letters.", role="feedback")
>>> response.backward(feedback)
>>> system.grad[0].data
'The system instruction should enforce the use of capital letters only.'
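The placeholder filling described above can be pictured with a small pure-Python sketch. It assumes placeholders use Python's str.format {name} syntax (as {language} does in the example) and uses plain strings instead of afnio Variables; fill_placeholders is a hypothetical helper, not part of the afnio API.

```python
from typing import Dict, List

# Hypothetical helper, NOT part of afnio. Plain strings stand in for Variables.
Message = Dict[str, object]


def fill_placeholders(messages: List[Message], inputs: Dict[str, str]) -> List[Message]:
    """Return a copy of messages with {name} placeholders filled from inputs.

    Assumes str.format-style placeholders; the original messages are left untouched.
    Keys in inputs that a message does not use are simply ignored by str.format.
    """
    filled = []
    for message in messages:
        contents = [text.format(**inputs) for text in message["content"]]
        filled.append({"role": message["role"], "content": contents})
    return filled


messages = [
    {"role": "system", "content": ["You are a helpful assistant."]},
    {"role": "user", "content": ["Translate 'Hello' to {language}."]},
]
print(fill_placeholders(messages, {"language": "Italian"})[1]["content"][0])
# prints: Translate 'Hello' to Italian.
```

Returning a copy keeps the template messages reusable across calls with different inputs.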
See also
afnio.autodiff.lm_ops.ChatCompletion for the underlying operation.
- T_destination = ~T_destination
- automatic_optimization: bool
- buffers(recurse=True)
Return an iterator over module buffers.
- Parameters:
recurse (bool) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module.
- Yields:
hf.Variable – module buffer
Example:
>>> for buf in model.buffers():
...     print(type(buf), buf.data)
<class 'afnio.Variable'> ("Structure your answer as JSON.")
<class 'afnio.Variable'> ("Use the format\n\n{\n \"response\": \"Your concise answer here.\"\n}")
- chats(recurse=True)
Return an iterator over module multi-turn chats.
This is typically passed to an optimizer.
- Parameters:
recurse (bool) – if True, then yields chats of this module and all submodules. Otherwise, yields only chats that are direct members of this module.
- Yields:
MultiTurnMessages – module chats
Example:
>>> for chat in pipeline.chats():
...     print(type(chat), chat)
<class 'cog.MultiTurnMessages'> [{'role': 'system', 'content': [Variable(data=You are a doctor., role=system instruction, requires_grad=False)]}, {'role': 'user', 'content': [Variable(data=Is {item} a disease?, role=user query, requires_grad=False)]}]
<class 'cog.MultiTurnMessages'> [{'role': 'system', 'content': [Variable(data=You are a helpful assistant., role=system instruction, requires_grad=False), Variable(data=Only answer with YES or NO., role=user query, requires_grad=False)]}]
- children()
Return an iterator over immediate children modules.
- Yields:
Module – a child module
- completion_configs(recurse=True)
Return an iterator over registered completion configs.
- Parameters:
recurse (bool) – if True, then yields completion configs of this module and all submodules. Otherwise, yields only completion configs that are direct members of this module.
- Yields:
dict – completion arguments
- Example:
>>> for config in model.completion_configs():
...     print(config)
{'model': 'gpt-4o', 'seed': 42, 'temperature': 0}
- configure_optimizers()
Configure and return the optimizer for this module.
This method should be implemented in subclasses to define the optimizer configuration. It is called by the Trainer to set up the optimization routine.
- Returns:
An instance of an optimizer configured for this module.
- Return type:
- Raises:
NotImplementedError – If not implemented in a subclass.
- empty_grad()
Reset gradients of all model parameters and content variables in chats’ messages.
This method is useful for clearing out gradients before starting a new optimization step. It ensures that both module parameters and Variables within multi-turn chat’s message contents have their gradients reset, avoiding unintended gradient accumulation.
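The reset performed by empty_grad() can be sketched in plain Python as follows. This is a hypothetical model, not afnio's implementation: Var stands in for a Variable, and storing grad as a list of feedback strings is an assumption suggested by the system.grad[0] access in the ChatCompletion example.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Var:
    """Hypothetical stand-in for an afnio Variable; grad is a list of feedback strings."""
    data: str
    grad: List[str] = field(default_factory=list)


def empty_grad(parameters: Dict[str, Var], chats: Dict[str, List[dict]]) -> None:
    """Reset gradients of parameters and of every Variable in chat message contents."""
    for param in parameters.values():
        param.grad = []
    for messages in chats.values():
        for message in messages:
            for var in message["content"]:
                var.grad = []


system = Var("You are a helpful assistant.", grad=["Use capital letters only."])
prompt = Var("Answer briefly.", grad=["Be more specific."])
empty_grad({"prompt": prompt}, {"messages": [{"role": "system", "content": [system]}]})
print(system.grad, prompt.grad)  # prints: [] []
```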
- eval()
Set the module in evaluation mode.
This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected.
This is equivalent to self.train(False).
- Returns:
self
- Return type:
- abstractmethod extra_repr()
Set the extra representation of the module.
To print customized extra information, you should re-implement this method in your own modules. Both single-line and multi-line strings are acceptable.
- forward(forward_model_client, messages, inputs=None, **completion_args)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
One should invoke the Module instance (the Module.__call__ method) instead of calling Module.forward() directly, so that hooks are registered and run.
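The reason to call the instance rather than forward() can be shown with a simplified, hypothetical MiniModule whose __call__ runs forward() and then any registered forward hooks. The hook list and its (module, args, output) signature are assumptions modeled on PyTorch-style hooks, not afnio's API.

```python
from typing import Any, Callable, List


class MiniModule:
    """Hypothetical, simplified module; NOT afnio's Module implementation."""

    def __init__(self) -> None:
        self.forward_hooks: List[Callable[["MiniModule", tuple, Any], None]] = []

    def forward(self, *args: Any) -> Any:
        raise NotImplementedError

    def __call__(self, *args: Any) -> Any:
        # Calling the instance runs forward() and then every registered hook.
        output = self.forward(*args)
        for hook in self.forward_hooks:
            hook(self, args, output)
        return output


class Echo(MiniModule):
    def forward(self, text: str) -> str:
        return text.upper()


calls = []
echo = Echo()
echo.forward_hooks.append(lambda module, args, out: calls.append(out))
echo("ciao")          # goes through __call__, so the hook runs
echo.forward("ciao")  # bypasses __call__, so the hook is skipped
print(calls)  # prints: ['CIAO']
```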
- forward_model_client: Optional[ChatCompletionModel]
- functions(recurse=True)
Return an iterator over registered functions.
- Parameters:
recurse (bool) – if True, then yields functions of this module and all submodules. Otherwise, yields only functions that are direct members of this module.
- Yields:
Callable – functions
- Example:
>>> for func in model.functions():
...     print(func)
<built-in function sum>
<function my_func at 0x7e7a0665b9c0>
- get_extra_state()
Return any extra state to include in the module’s state_dict.
Implement this and a corresponding set_extra_state() for your module if you need to store extra state. This function is called when building the module’s state_dict().
Note that extra state should be picklable to ensure working serialization of the state_dict.
- Returns:
Any extra state to store in the module’s state_dict.
- Return type:
- load_state_dict(state_dict, strict=True, assign=False, model_clients=None)
Copy parameters, buffers, chats, models, completion configs and functions from state_dict into this module and its descendants.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
Warning
If assign is True, the optimizer must be created after the call to load_state_dict.
- Parameters:
state_dict (dict) – A dict containing parameters, persistent buffers, chats, models, completion configs and functions.
strict (bool, optional) – Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
assign (bool, optional) – When False, the properties of the Variables in the current module are preserved; when True, the properties of the Variables in the state dict are preserved. The only exception is the requires_grad field of Parameters, for which the value from the module is preserved. Default: False
model_clients (dict, optional) – A dictionary mapping model client keys (e.g., ‘fw_model_client’) to their respective instances of BaseModel. These instances will be used to reconstruct any model clients referenced within the optimizer state. If a required model client is missing, an error will be raised with instructions on how to provide the missing client.
- Returns:
- missing_keys is a list of str containing any keys that are expected by this module but missing from the provided state_dict.
- unexpected_keys is a list of str containing the keys that are not expected by this module but present in the provided state_dict.
- Return type:
NamedTuple with missing_keys and unexpected_keys fields
Note
If a parameter, buffer, chat, model, completion config, or function is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.
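The strict key check and the returned missing_keys/unexpected_keys pair can be sketched as below. check_keys and IncompatibleKeys are hypothetical names, and raising RuntimeError on a strict mismatch is an assumption modeled on PyTorch; the documentation above only specifies the RuntimeError for keys registered as None.

```python
from typing import Dict, List, NamedTuple


class IncompatibleKeys(NamedTuple):
    missing_keys: List[str]
    unexpected_keys: List[str]


def check_keys(expected: List[str], state_dict: Dict[str, object], strict: bool = True) -> IncompatibleKeys:
    """Hypothetical sketch of strict key matching; not afnio's implementation."""
    missing = [key for key in expected if key not in state_dict]
    unexpected = [key for key in state_dict if key not in expected]
    if strict and (missing or unexpected):
        # Assumed behavior: a strict mismatch is reported as an error.
        raise RuntimeError(
            f"Error(s) in loading state_dict: missing={missing}, unexpected={unexpected}"
        )
    return IncompatibleKeys(missing, unexpected)


expected = ["system_prompt", "format_type", "chat.completion_args"]
loaded = {"system_prompt": "You are a doctor.", "user_prompt": "Is {item} a disease?"}
print(check_keys(expected, loaded, strict=False))
# prints: IncompatibleKeys(missing_keys=['format_type', 'chat.completion_args'], unexpected_keys=['user_prompt'])
```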
- models(recurse=True)
Return an iterator over module language model clients.
- Parameters:
recurse (bool) – if True, then yields models of this module and all submodules. Otherwise, yields only models that are direct members of this module.
- Yields:
BaseModel – module model
Example:
>>> for model in pipeline.models():
...     print(type(model))
<class 'afnio.models.openai.AsyncOpenAI'>
- modules()
Return an iterator over all modules in the network.
- Yields:
Module – a module in the network
Note
Duplicate modules are returned only once. In the following example, add will be returned only once.
Example:
>>> class MyPipeline(cog.Module):
...     def __init__(self):
...         super().__init__()
...         add = cog.Add()
...         self.module1 = add
...         self.module2 = add
...     def forward(self, x, y):
...         out1 = self.module1(x, x)
...         out2 = self.module2(x, y)
...         return out1 + out2
>>> pipeline = MyPipeline()
>>> for idx, m in enumerate(pipeline.modules()):
...     print(idx, '->', m)
0 -> MyPipeline(
  (module1): Add()
  (module2): Add()
)
1 -> Add()
- named_buffers(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module buffers, yielding both the name of the buffer as well as the buffer itself.
- Parameters:
prefix (str) – prefix to prepend to all buffer names.
recurse (bool, optional) – if True, then yields buffers of this module and all submodules. Otherwise, yields only buffers that are direct members of this module. Defaults to True.
remove_duplicate (bool, optional) – whether to remove the duplicated buffers in the result. Defaults to True.
- Yields:
(str, hf.Variable) – Tuple containing the name and buffer
Example:
>>> for name, buf in self.named_buffers():
...     if "format_type" in name:
...         print(buf.data)
- named_chats(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module multi-turn chats, yielding both the name of chat as well as the chat itself.
- Parameters:
prefix (str) – prefix to prepend to all chat names.
recurse (bool) – if True, then yields chats of this module and all submodules. Otherwise, yields only chats that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated chats in the result. Defaults to True.
- Yields:
(str, MultiTurnMessages) – Tuple containing the name and chat
Example:
>>> for name, chat in self.named_chats():
...     if "messages" in name:
...         print(chat[0]["role"])
- named_children()
Return an iterator over immediate children modules, yielding both the name of the module as well as the module itself.
- Yields:
(str, Module) – Tuple containing a name and child module
- named_completion_configs(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module completion configs, yielding both the name of the completion config as well as the completion config itself.
- Parameters:
prefix (str) – prefix to prepend to all completion config names.
recurse (bool) – if True, then yields completion configs of this module and all submodules. Otherwise, yields only completion configs that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated completion configs in the result. Defaults to True.
- Yields:
(str, dict) – Tuple containing the name and completion configs
Example:
>>> for name, config in self.named_completion_configs():
...     print(name, config)
chat.completion_args {'model': 'gpt-4o', 'seed': 42, 'temperature': 0}
- named_functions(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module functions, yielding both the name of the function as well as the function itself.
- Parameters:
prefix (str) – prefix to prepend to all function names.
recurse (bool) – if True, then yields functions of this module and all submodules. Otherwise, yields only functions that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated functions in the result. Defaults to True.
- Yields:
(str, Callable) – Tuple containing the name and functions
Example:
>>> for name, func in self.named_functions():
...     print(name, func)
reduction_fn <built-in function sum>
eval_fn <function my_func at 0x7e7a0665b9c0>
- named_models(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module model clients, yielding both the name of the model as well as the model itself.
- Parameters:
prefix (str) – prefix to prepend to all model names.
recurse (bool) – if True, then yields models of this module and all submodules. Otherwise, yields only models that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated models in the result. Defaults to True.
- Yields:
(str, BaseModel) – Tuple containing the name and model
Example:
>>> for name, model in self.named_models():
...     print(name, type(model))
model_client <class 'afnio.models.openai.AsyncOpenAI'>
- named_modules(memo=None, prefix='', remove_duplicate=True)
Return an iterator over all modules in the network, yielding both the name of the module as well as the module itself.
- Parameters:
- Yields:
(str, Module) – Tuple of name and module
Note
Duplicate modules are returned only once. In the following example, add will be returned only once.
Example:
>>> class MyPipeline(cog.Module):
...     def __init__(self):
...         super().__init__()
...         add = cog.Add()
...         self.module1 = add
...         self.module2 = add
...     def forward(self, x, y):
...         out1 = self.module1(x, x)
...         out2 = self.module2(x, y)
...         return out1 + out2
>>> pipeline = MyPipeline()
>>> for idx, m in enumerate(pipeline.named_modules()):
...     print(idx, '->', m)
0 -> ('', MyPipeline(
  (module1): Add()
  (module2): Add()
))
1 -> ('module1', Add())
Example:
>>> class MyPipeline(cog.Module):
...     def __init__(self):
...         super().__init__()
...         add = cog.Add()
...         self.module1 = add
...         self.module2 = add
...     def forward(self, x, y):
...         out1 = self.module1(x, x)
...         out2 = self.module2(x, y)
...         return out1 + out2
>>> pipeline = MyPipeline()
>>> for idx, m in enumerate(pipeline.named_modules(remove_duplicate=False)):
...     print(idx, '->', m)
0 -> ('', MyPipeline(
  (module1): Add()
  (module2): Add()
))
1 -> ('module1', Add())
2 -> ('module2', Add())
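Duplicate removal in named_modules() can be sketched with a hypothetical Node tree that tracks visited instances in memo, so a module registered under two names is yielded only once unless remove_duplicate=False. This mirrors the examples above but is not afnio's code; tracking instances by id() is an illustrative choice.

```python
from typing import Dict, Iterator, Optional, Set, Tuple


class Node:
    """Hypothetical minimal module tree; children live in _modules as in the docs above."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, "Node"] = {}

    def named_modules(
        self, memo: Optional[Set[int]] = None, prefix: str = "", remove_duplicate: bool = True
    ) -> Iterator[Tuple[str, "Node"]]:
        if memo is None:
            memo = set()
        if remove_duplicate:
            if id(self) in memo:
                return  # this instance was already yielded under another name
            memo.add(id(self))
        yield prefix, self
        for child_name, child in self._modules.items():
            child_prefix = f"{prefix}.{child_name}" if prefix else child_name
            yield from child.named_modules(memo, child_prefix, remove_duplicate)


pipeline = Node("MyPipeline")
add = Node("Add")
pipeline._modules["module1"] = add
pipeline._modules["module2"] = add  # the same instance registered twice
print([name for name, _ in pipeline.named_modules()])
# prints: ['', 'module1']
print([name for name, _ in pipeline.named_modules(remove_duplicate=False)])
# prints: ['', 'module1', 'module2']
```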
- named_parameters(prefix='', recurse=True, remove_duplicate=True)
Return an iterator over module parameters, yielding both the name of the parameter as well as the parameter itself.
- Parameters:
prefix (str) – prefix to prepend to all parameter names.
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
remove_duplicate (bool, optional) – whether to remove the duplicated parameters in the result. Defaults to True.
- Yields:
(str, Parameter) – Tuple containing the name and parameter
Example:
>>> for name, param in self.named_parameters():
...     if "prompt" in name:
...         print(param.data)
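How prefix and recurse shape the yielded names can be sketched with a hypothetical Mod class whose parameters are plain strings; child names are joined with dots, matching names such as chat.completion_args shown for named_completion_configs(). Duplicate removal is omitted for brevity, and none of this is afnio's implementation.

```python
from typing import Dict, Iterator, Tuple


class Mod:
    """Hypothetical stand-in: parameters are plain strings, children are Mod instances."""

    def __init__(self, params: Dict[str, str], children: Dict[str, "Mod"]) -> None:
        self._parameters = params
        self._modules = children

    def named_parameters(self, prefix: str = "", recurse: bool = True) -> Iterator[Tuple[str, str]]:
        # Direct members first, with the prefix prepended when one is given.
        for name, param in self._parameters.items():
            yield (f"{prefix}.{name}" if prefix else name), param
        if recurse:
            # Then each child's parameters, qualified by the child's attribute name.
            for child_name, child in self._modules.items():
                child_prefix = f"{prefix}.{child_name}" if prefix else child_name
                yield from child.named_parameters(child_prefix, recurse)


classifier = Mod({"system_prompt": "You are a doctor."}, {})
pipeline = Mod({"user_prompt": "Is {item} a disease?"}, {"classifier": classifier})
print(dict(pipeline.named_parameters()))
# prints: {'user_prompt': 'Is {item} a disease?', 'classifier.system_prompt': 'You are a doctor.'}
print(dict(pipeline.named_parameters(recurse=False)))
# prints: {'user_prompt': 'Is {item} a disease?'}
```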
- optimizers()
Returns the optimizer(s) that are being used during training. Useful for manual optimization.
This method is useful for accessing the optimizer(s) configured in the configure_optimizers() method by the fit() method.
Example:
>>> optimizers = model.optimizers()
>>> for optimizer in optimizers:
...     print(optimizer)
TGD (
Parameter Group 0
    completion_args: {'model': 'gpt-4.1'}
    constraints: []
    inputs: {}
    messages: [
      {'role': 'system', 'content': [Variable(data="Placeholder Textual Gradient Descent optimizer system prompt", role=Textual Gradient Descent optimizer system prompt, requires_grad=False)]},
      {'role': 'user', 'content': [Variable(data="Placeholder for Textual Gradient Descent optimizer user prompt", role=Textual Gradient Descent optimizer user prompt, requires_grad=False)]}
    ]
    model_client: <afnio.models.openai.AsyncOpenAI object at 0x710df9c149a0>
    momentum: 3
)
- parameters(recurse=True)
Return an iterator over module parameters.
This is typically passed to an optimizer.
- Parameters:
recurse (bool) – if True, then yields parameters of this module and all submodules. Otherwise, yields only parameters that are direct members of this module.
- Yields:
Parameter – module parameter
Example:
>>> for param in pipeline.parameters():
...     print(type(param), param.data)
<class 'cog.Parameter'> ("You are a doctor.")
<class 'cog.Parameter'> ("Only answer with YES or NO.")
- register_buffer(name, variable, persistent=True)
Add a buffer to the module.
This is typically used to register a buffer that should not be considered a model parameter. For example, Prompt’s format_type is not a parameter, but is part of the module’s state. Buffers, by default, are persistent and will be saved alongside parameters. This behavior can be changed by setting persistent to False. The only difference between a persistent buffer and a non-persistent buffer is that the latter will not be a part of this module’s state_dict.
Buffers can be accessed as attributes using the given names.
- Parameters:
name (str) – Name of the buffer. The buffer can be accessed from this module using the given name.
variable (Variable or None) – Buffer to be registered. If None, then operations that run on buffers are ignored, and the buffer is not included in the module’s state_dict.
persistent (bool) – Whether the buffer is part of this module’s state_dict.
- Example:
>>> self.register_buffer('format_type', hf.Variable(data="Structure your answer as JSON.", role="format type"))
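The effect of persistent=False on state_dict() can be sketched with a hypothetical BufferHolder that stores buffers as plain strings. It follows the rules stated above (non-persistent and None buffers are left out of the state dict) but is not afnio's implementation.

```python
from typing import Dict, Optional, Set


class BufferHolder:
    """Hypothetical sketch: buffers are strings; names in _non_persistent are skipped."""

    def __init__(self) -> None:
        self._buffers: Dict[str, Optional[str]] = {}
        self._non_persistent: Set[str] = set()

    def register_buffer(self, name: str, value: Optional[str], persistent: bool = True) -> None:
        self._buffers[name] = value
        if persistent:
            self._non_persistent.discard(name)
        else:
            self._non_persistent.add(name)

    def state_dict(self) -> Dict[str, str]:
        # None buffers and non-persistent buffers are both left out.
        return {
            name: value
            for name, value in self._buffers.items()
            if value is not None and name not in self._non_persistent
        }


holder = BufferHolder()
holder.register_buffer("format_type", "Structure your answer as JSON.")
holder.register_buffer("scratch", "temporary notes", persistent=False)
holder.register_buffer("unused", None)
print(holder.state_dict())  # prints: {'format_type': 'Structure your answer as JSON.'}
```

The non-persistent buffer is still stored on the holder; it is only excluded from what state_dict() returns.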
- register_chat(name, messages)
Add multi-turn chat messages to the module.
The chat can be accessed as an attribute using given name.
- Parameters:
name (str) – Name of the chat. The chat can be accessed from this module using the given name.
messages (MultiTurnMessages or None) – Chat to be added to the module. If None, then operations that run on chats are ignored, and the chat is not included in the module’s state_dict.
- register_completion_config(name, args)
Register completion-specific arguments for text generation.
This method allows dynamic storage of completion-related parameters such as temperature, max_tokens, top_p, etc.
- Parameters:
name (str) – Name of the completion argument set.
args (dict or None) – Dictionary of completion arguments. If None, the argument set is not included in the module’s state_dict.
- register_function(name, func)
Add a function to the module.
The function can be accessed as an attribute using given name.
- Parameters:
name (str) – Name of the function. The function can be accessed from this module using the given name.
func (FunctionType or None) – A standard Python function (i.e., a def-defined function, not a lambda or callable object) that can be pickled and registered for later execution. If None, the function is unregistered and is not included in the module’s state_dict.
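The documented restriction on func can be expressed as a hypothetical check, is_registrable_function, that accepts None or a def-defined function and rejects lambdas and callable objects. Whether afnio performs exactly this check is an assumption; the sketch only encodes the rule as written.

```python
import types
from typing import Any


def is_registrable_function(func: Any) -> bool:
    """Hypothetical check: accept only plain def-defined functions (or None)."""
    if func is None:
        return True  # None unregisters the function
    # Lambdas are FunctionType too, so they are told apart by their name.
    return isinstance(func, types.FunctionType) and func.__name__ != "<lambda>"


def my_func(x):
    return x * 2


class Doubler:
    def __call__(self, x):
        return x * 2


print(is_registrable_function(my_func))          # prints: True
print(is_registrable_function(lambda x: x * 2))  # prints: False
print(is_registrable_function(Doubler()))        # prints: False
```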
- register_model(name, model)
Add a language model to the module.
The language model can be accessed as an attribute using given name.
- Parameters:
name (str) – Name of the model. The model can be accessed from this module using the given name.
model (BaseModel or None) – Model to be added to the module. If None, then operations that run on models are ignored, and the model is not included in the module’s state_dict.
- register_module(name, module)
Add a child module to the current module.
This method explicitly adds a child module to the current module’s hierarchy. The child module can then be accessed as an attribute using the given name and will be registered in the _modules dictionary.
When to use:
- Use register_module() when dynamically adding submodules at runtime, especially when the submodule name is determined programmatically.
- This can be useful for creating flexible and modular architectures.
When it’s unnecessary:
- Directly assigning the module to an attribute (e.g., self.module_name = SubModule()) automatically registers it, so using register_module() is unnecessary in such cases.
- Parameters:
- Raises:
- Example:
>>> class DynamicPipeline(cog.Module):
...     def __init__(self):
...         super().__init__()
...         # Dynamically add submodules
...         for i in range(3):
...             self.register_module(f"layer_{i}", cog.Module())
>>> pipeline = DynamicPipeline()
>>> print(pipeline._modules.keys())
odict_keys(['layer_0', 'layer_1', 'layer_2'])
Note
If assigning submodules using standard attribute assignment (e.g., self.submodule = SubModule()), calling register_module() explicitly is not required. Direct assignment automatically registers the module.
- register_parameter(name, param)
Add a parameter to the module.
The parameter can be accessed as an attribute using given name.
- Parameters:
name (str) – Name of the parameter. The parameter can be accessed from this module using the given name.
param (Parameter or None) – Parameter to be added to the module. If None, then operations that run on parameters are ignored, and the parameter is not included in the module’s state_dict.
- requires_grad_(requires_grad=True)
Change if autodiff should record operations on parameters and chats in this module.
This method sets the requires_grad attributes of all module parameters in-place. It also sets the requires_grad attributes of all the Variables within the content of multi-turn chats.
- Effect on Parameters:
Sets requires_grad for each registered parameter in the module.
- Effect on Chats:
Iterates through all multi-turn chats and sets requires_grad for each Variable in the “content” key of each chat message.
This method is helpful for freezing part of the module for finetuning or training parts of a model individually.
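The in-place update described above can be sketched in plain Python, for example to freeze a module before fine-tuning another part of a pipeline. Var and the dictionaries of parameters and chats are hypothetical stand-ins for afnio's Parameters and MultiTurnMessages, not its actual classes.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class Var:
    """Hypothetical stand-in for an afnio Variable or Parameter."""
    data: str
    requires_grad: bool = False


def requires_grad_(parameters: Dict[str, Var], chats: Dict[str, List[dict]], requires_grad: bool = True) -> None:
    """Set requires_grad in place on parameters and on every Variable in chat contents."""
    for param in parameters.values():
        param.requires_grad = requires_grad
    for messages in chats.values():
        for message in messages:
            for var in message["content"]:
                var.requires_grad = requires_grad


prompt = Var("Only answer with YES or NO.", requires_grad=True)
system = Var("You are a doctor.", requires_grad=True)
chats = {"messages": [{"role": "system", "content": [system]}]}
requires_grad_({"prompt": prompt}, chats, requires_grad=False)  # freeze everything
print(prompt.requires_grad, system.requires_grad)  # prints: False False
```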
- set_extra_state(state)
Set extra state contained in the loaded state_dict.
This function is called from load_state_dict() to handle any extra state found within the state_dict. Implement this function and a corresponding get_extra_state() for your module if you need to store extra state within its state_dict.
- Parameters:
state (dict) – Extra state from the state_dict.
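A round trip through get_extra_state() and set_extra_state() can be sketched with a hypothetical RetryPolicy class; pickling the returned dict is a quick way to confirm that the extra state is serializable, as required above. The class and its fields are illustrative only and do not come from afnio.

```python
import pickle
from typing import Any, Dict


class RetryPolicy:
    """Hypothetical module-like class that keeps non-Variable extra state."""

    def __init__(self) -> None:
        self.max_retries = 3
        self.history: list = []

    def get_extra_state(self) -> Dict[str, Any]:
        # Must be picklable so the whole state_dict can be serialized.
        return {"max_retries": self.max_retries, "history": list(self.history)}

    def set_extra_state(self, state: Dict[str, Any]) -> None:
        self.max_retries = state["max_retries"]
        self.history = list(state["history"])


source = RetryPolicy()
source.max_retries = 5
source.history.append("timeout")
blob = pickle.dumps(source.get_extra_state())  # fails if the state is not picklable

target = RetryPolicy()
target.set_extra_state(pickle.loads(blob))
print(target.max_retries, target.history)  # prints: 5 ['timeout']
```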
- state_dict(*, destination=None, prefix='', keep_vars=False)
Return a dictionary containing references to the whole state of the module.
Parameters, persistent buffers (e.g. a prompt’s format_type), multi-turn chats, models, completion configs and functions are included. Keys are the corresponding parameter, buffer, chat, model, completion config and function names. Parameters, buffers, chats, models, completion configs and functions set to None are not included.
Note
The returned object is a shallow copy. It contains references to the module’s parameters, buffers, chats, models, completion configs and functions.
Warning
Please avoid the use of the argument destination, as it is not designed for end-users.
- Parameters:
destination (dict, optional) – If provided, the state of the module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.
prefix (str, optional) – A prefix added to parameter, buffer, chat, model, completion config and function names to compose the keys in state_dict. Default: ''.
keep_vars (bool, optional) – By default the Variables returned in the state dict are detached from autodiff. If it’s set to True, detaching will not be performed. Default: False.
- Returns:
A dictionary containing a whole state of the module.
- Return type:
Example:
>>> module.state_dict().keys() ['system_prompt', 'classification_labels', 'format_type', 'user_prompt']
- test_step(batch, batch_idx)
Perform a single test step.
This method should be implemented in subclasses to define the test logic. It is called by the Trainer during the testing loop.
- Parameters:
- Returns:
- The loss as a tuple of two Variables:
  The evaluation score (a Variable containing the loss value).
  The explanation (a Variable containing a string explanation of the evaluation result).
- dict: A dictionary. It can include any keys, but must include the key 'loss' containing a tuple of two Variables (score and explanation).
- None: Skip to the next batch.
- Return type:
Tuple[Variable, Variable]
- Raises:
NotImplementedError – If not implemented in a subclass.
- train(mode=True)
Set the module in training mode.
This has an effect only on certain modules. See the documentation of particular modules for details of their behavior in training/evaluation mode, if they are affected.
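A plausible shape for train() and eval() is sketched below with a hypothetical Mod class: the training flag is set recursively on child modules, and eval() simply delegates to train(False). The recursion and the return of self from train() are assumptions modeled on PyTorch, not taken from afnio's source.

```python
from typing import Dict


class Mod:
    """Hypothetical sketch of train()/eval() propagation; not afnio's implementation."""

    def __init__(self, **children: "Mod") -> None:
        self.training = True
        self._modules: Dict[str, "Mod"] = dict(children)

    def train(self, mode: bool = True) -> "Mod":
        self.training = mode
        for child in self._modules.values():
            child.train(mode)
        return self  # returning self allows chaining, e.g. model.eval().training

    def eval(self) -> "Mod":
        return self.train(False)


pipeline = Mod(chat=Mod(), judge=Mod())
print(pipeline.eval().training, pipeline._modules["chat"].training)  # prints: False False
pipeline.train()
print(pipeline._modules["judge"].training)  # prints: True
```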
- training: bool
- training_step(batch, batch_idx)
Perform a single training step.
This method should be implemented in subclasses to define the training logic. It is called by the Trainer during the training loop.
- Parameters:
- Returns:
- The loss as a tuple of two Variables:
  The evaluation score (a Variable containing the loss value).
  The explanation (a Variable containing a string explanation of the evaluation result).
- dict: A dictionary. It can include any keys, but must include the key 'loss' containing a tuple of two Variables (score and explanation).
- None: Skip to the next batch.
- Return type:
Tuple[Variable, Variable]
- Raises:
NotImplementedError – If not implemented in a subclass.
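The three return forms accepted from training_step() (and likewise from test_step() and validation_step()) can be normalized by a hypothetical helper such as extract_loss below. Plain floats and strings stand in for the score and explanation Variables, and the error types raised for malformed returns are assumptions.

```python
from typing import Any, Optional, Tuple


def extract_loss(step_output: Any) -> Optional[Tuple[Any, Any]]:
    """Hypothetical helper: normalize a step's return value to (score, explanation) or None."""
    if step_output is None:
        return None  # the caller skips to the next batch
    if isinstance(step_output, dict):
        if "loss" not in step_output:
            raise KeyError("a dict returned from a step must contain the key 'loss'")
        step_output = step_output["loss"]
    if not (isinstance(step_output, tuple) and len(step_output) == 2):
        raise TypeError("the loss must be a (score, explanation) tuple")
    return step_output


print(extract_loss((1.0, "Correct answer.")))                           # prints: (1.0, 'Correct answer.')
print(extract_loss({"loss": (0.0, "Wrong label."), "accuracy": 0.0}))  # prints: (0.0, 'Wrong label.')
print(extract_loss(None))                                               # prints: None
```

Extra dictionary keys such as accuracy are ignored here; a trainer could log them separately.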
- validation_step(batch, batch_idx)
Perform a single validation step.
This method should be implemented in subclasses to define the validation logic. It is called by the Trainer during the validation loop.
- Parameters:
- Returns:
- The loss as a tuple of two Variables:
  The evaluation score (a Variable containing the loss value).
  The explanation (a Variable containing a string explanation of the evaluation result).
- dict: A dictionary. It can include any keys, but must include the key 'loss' containing a tuple of two Variables (score and explanation).
- None: Skip to the next batch.
- Return type:
Tuple[Variable, Variable]
- Raises:
NotImplementedError – If not implemented in a subclass.