import logging
from typing import Callable, Dict, List, Optional, Tuple, Union
from afnio._utils import (
MultiTurnMessages,
_deserialize_output,
_serialize_arg,
)
from afnio._variable import Variable
from afnio.logging_config import configure_logging
from afnio.models import ChatCompletionModel
from afnio.tellurio._client_manager import get_default_clients
from afnio.tellurio._eventloop import run_in_background_loop
from afnio.tellurio._variable_registry import (
suppress_variable_notifications,
)
from .optimizer import Optimizer, ParamsT, _extract_variable_ids, _wait_for_variable
# Configure logging
configure_logging()
logger = logging.getLogger(__name__)
# Suppress notifications for variable changes during module initialization
# as the WebSocket connection is not established yet
with suppress_variable_notifications():
# TGD_MESSAGES is only a placeholder that will be replaced on the server side by the
# actual system prompt and user instruction
TGD_MESSAGES = [
{
"role": "system",
"content": [
Variable(
data="Placeholder for Textual Gradient Descent optimizer system prompt", # noqa: E501
role="Textual Gradient Descent optimizer system prompt",
)
],
},
{
"role": "user",
"content": [
Variable(
data="Placeholder for Textual Gradient Descent optimizer user prompt", # noqa: E501
role="Textual Gradient Descent optimizer user prompt",
)
],
},
]
class TGD(Optimizer):
def __init__(
self,
params: ParamsT,
model_client: Optional[ChatCompletionModel],
messages: MultiTurnMessages = TGD_MESSAGES,
inputs: Optional[Dict[str, Union[str, Variable]]] = None,
constraints: Optional[List[Union[str, Variable]]] = None,
momentum: int = 0,
**completion_args,
):
"""
Textual Gradient Descent (TGD) optimizer.
Args:
params (ParamsT): Iterable of parameters to optimize or dicts defining
parameter groups.
model_client (Optional[ChatCompletionModel]): LM model client used
for optimization.
            messages (MultiTurnMessages): Messages for multi-turn interactions. These
                typically define the optimizer system prompt and user instruction.
                In-context examples (shots) can be added as well.
            inputs (Optional[Dict[str, Union[str, Variable]]]): Dynamic values used to
                fill placeholders within the message templates.
constraints (Optional[List[Union[str, Variable]]]): A list of natural
language constraints for optimization.
momentum (int, optional): Momentum window size. Tracks the last `momentum`
gradients, which helps accelerate updates in the right direction and
dampen oscillations. Defaults to 0.
completion_args (Dict[str, Any], optional): Additional arguments to pass to
the model client when generating text completions. Defaults to an
empty dictionary.
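
        Example:
            A minimal usage sketch; `model_client` stands in for any configured
            `ChatCompletionModel` and `system_prompt` for a `Variable` parameter
            being optimized (both names are illustrative)::

                optimizer = TGD([system_prompt], model_client=model_client, momentum=3)
                loss = optimizer.step(closure)  # `closure` re-evaluates the loss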
"""
# Workaround to trigger TGD_MESSAGES registration with the server
# and store related variable_ids on the client side
if messages is TGD_MESSAGES:
messages = [
{
"role": "system",
"content": [
Variable(
data="Placeholder for Textual Gradient Descent optimizer system prompt", # noqa: E501
role="Textual Gradient Descent optimizer system prompt",
)
],
},
{
"role": "user",
"content": [
Variable(
data="Placeholder for Textual Gradient Descent optimizer user prompt", # noqa: E501
role="Textual Gradient Descent optimizer user prompt",
)
],
},
]
defaults = dict(
model_client=model_client,
messages=messages,
inputs=inputs or {},
constraints=constraints or [],
momentum=momentum,
completion_args=completion_args,
)
super().__init__(params, defaults)
def step(
self, closure: Optional[Callable] = None
) -> Optional[Tuple[Variable, Variable]]:
"""Performs a single optimization step.
Args:
closure (Optional[Callable]): A closure that reevaluates the model
and returns the loss.
Returns:
            Optional[Tuple[Variable, Variable]]: The loss if `closure` is provided,
                otherwise None. The loss consists of a numerical or textual score
                and a textual explanation, both wrapped as `Variable` objects.
        """
        loss = closure() if closure else None
super().step()
return loss
def _extract_variable_ids_from_state(self, state):
"""
Extract only the variable_ids of deepcopied parameters (i.e., those generated
on the server) from the optimizer state.
Args:
state (list): The serialized optimizer state as returned by the server.
Returns:
Set[str]: Set of variable_ids for deepcopied parameters.
"""
var_ids = set()
for entry in state:
momentum_buffer = entry.get("value", {}).get("momentum_buffer", [])
for buf_entry in momentum_buffer:
if (
isinstance(buf_entry, list)
and len(buf_entry) > 0
and isinstance(buf_entry[0], dict)
and "variable_id" in buf_entry[0]
):
var_ids.add(buf_entry[0]["variable_id"])
return var_ids
# TODO: Fix error when passing str constraints
def tgd(
params: List[Variable],
grads: List[List[Variable]],
momentum_buffer_list: List[Optional[List[Variable]]],
model_client: Optional[ChatCompletionModel],
messages: MultiTurnMessages,
inputs: Optional[Dict[str, Union[str, Variable]]],
constraints: Optional[List[Union[str, Variable]]],
momentum: int,
**completion_args,
):
r"""Functional API that performs TGD (Textual Gradient Descent) algorithm
computation.
See :class:`~afnio.optim.SGD` for details.
"""
# Set `_pending_data` for all parameters that will be optimized
for p in params:
p._pending_data = True
logger.debug(f"Marked variable {p.variable_id!r} as pending for data update.")
try:
_, ws_client = get_default_clients()
payload = {
"params": _serialize_arg(params),
"grads": _serialize_arg(grads),
"momentum_buffer_list": _serialize_arg(momentum_buffer_list),
"model_client": _serialize_arg(model_client),
"messages": _serialize_arg(messages),
"inputs": _serialize_arg(inputs),
"constraints": _serialize_arg(constraints),
"momentum": momentum,
"completion_args": _serialize_arg(completion_args),
}
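        # Expected response shape, as handled below (illustrative, not a full schema):
        #     {"error": {"data": {"exception": ...}, ...}}                  on failure
        #     {"result": {"message": ..., "momentum_buffer_list": [...]}}   on success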
response = run_in_background_loop(ws_client.call("run_optimizer_tgd", payload))
if "error" in response:
raise RuntimeError(
response["error"]["data"].get("exception", response["error"])
)
logger.debug(f"TGD optimization request sent: {payload!r}")
result = response.get("result", {})
result_message = result.get("message")
result_momentum_buffer_list = result.get("momentum_buffer_list", [])
# Extract all variable_ids from the result_momentum_buffer_list
# and wait for them to be registered in VARIABLE_REGISTRY
all_var_ids = _extract_variable_ids(result_momentum_buffer_list)
for var_id in all_var_ids:
_wait_for_variable(var_id)
des_momentum_buffer_list = _deserialize_output(result_momentum_buffer_list)
# Convert [param, grads] lists to (param, grads) tuples
for i, buffer in enumerate(des_momentum_buffer_list):
des_momentum_buffer_list[i] = [
tuple(pair) if isinstance(pair, list) and len(pair) == 2 else pair
for pair in buffer
]
if result_message != "Functional TGD optimization step executed successfully.":
raise RuntimeError(
f"Server did not return any data for functional TGD optimization: "
f"payload={payload!r}, response={response!r}"
)
# Update the momentum_buffer_list with the deserialized buffers
momentum_buffer_list.clear()
momentum_buffer_list.extend(des_momentum_buffer_list)
logger.debug("Functional TGD optimization executed successfully")
except Exception as e:
logger.error(f"Failed to run functional TGD optimization on the server: {e!r}")
# Clear all pending data flags to avoid deadlocks
for p in params:
p._pending_data = False
logger.debug(
f"Marked variable {p.variable_id!r} as not pending for data update "
f"after error."
)
raise