Source code for afnio._variable

import inspect
import logging
import threading
import time
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, List, Optional, Union

from afnio.logging_config import configure_logging
from afnio.tellurio._client_manager import get_default_clients
from afnio.tellurio._eventloop import run_in_background_loop
from afnio.tellurio._variable_registry import (
    is_variable_notify_suppressed,
    register_variable,
)

# Import `Node` only for type hints to avoid runtime circular imports; `TYPE_CHECKING`
# ensures it's available for static analysis (e.g., mypy) without executing at runtime.
if TYPE_CHECKING:
    from afnio.autodiff.graph import Node

from copy import deepcopy

import afnio

# Thread-local flag to control the assignment of `grad_fn` attributes
_grad_fn_assignment_allowed = threading.local()
_grad_fn_assignment_allowed.value = False

# Configure logging
configure_logging()
logger = logging.getLogger(__name__)


class Variable:
    """
    A class to represent generic data, such as textual inputs, outputs, or numeric
    data.

    Attributes:
        data (str | int | float | List[Union[str, int, float]]): The raw data, which
            can be a single string or numeric value or a list of single string or
            numeric values.
        requires_grad (bool): Whether to track operations for automatic
            differentiation.
        role (str): A specific description of the role of the variable in the model.
        grad (List[Variable]): Stores the gradients of the variable, if
            `requires_grad` is set to True and backpropagation has been performed.
    """

    # Using forward references for `Variable` and `Node` defined later
    requires_grad: bool
    _grad: List["Variable"]
    # TODO: Consider having `VariableMeta` class with `.grad_fn` and `.output_nr`
    # as attributes
    _output_nr: Optional[int]
    _grad_fn: Optional["Node"]
    _retain_grad: bool

    is_leaf: bool
    r"""All Variables that have :attr:`requires_grad` which is ``False`` will be leaf
    Variables by convention.

    For Variables that have :attr:`requires_grad` which is ``True``, they will be leaf
    Variables if they were created by the user. This means that they are not the
    result of an operation and so :attr:`grad_fn` is None.

    Only leaf Variables will have their :attr:`grad` populated during a call to
    :func:`backward`. To get :attr:`grad` populated for non-leaf Variables, you can
    use :func:`retain_grad`.

    Example::

        >>> a = hf.Variable("abc", requires_grad=True)
        >>> a.is_leaf
        True
        >>> b = hf.Variable("abc", requires_grad=True).upper()
        >>> b.is_leaf
        False
        # b was created by the operation that converts all string characters to uppercase
        >>> c = hf.Variable("abc", requires_grad=True) + "def"
        >>> c.is_leaf
        False
        # c was created by the addition operation
        >>> d = hf.Variable("abc").upper()
        >>> d.is_leaf
        True
        # d does not require gradients and so has no operation creating it (that is tracked by the autodiff engine)
        >>> e = hf.Variable("abc").upper().requires_grad_()
        >>> e.is_leaf
        True
        # e requires gradients and has no operations creating it
    """  # noqa: E501

    variable_id: Optional[str]
    _initialized: bool
    _pending_grad_fn_id: Optional[str]
    _pending_grad: Optional[bool]
    _pending_data: Optional[bool]

    def __init__(
        self,
        data: Optional[Union[str, int, float, List[Union[str, int, float]]]] = "",
        role: str = "",
        requires_grad: bool = False,
    ):
        if not isinstance(data, (str, int, float, list, tuple)):
            raise TypeError(
                "`data` must be a single value (str, int, float) or a list/tuple of "
                "such values."
            )

        if isinstance(data, (list, tuple)):
            # Check if the list/tuple is homogeneous (all strings or all numbers)
            all_strings = all(isinstance(d, str) for d in data)
            all_numbers = all(isinstance(d, (int, float)) for d in data)
            if not (all_strings or all_numbers):
                raise TypeError(
                    f"When `data` is a {type(data).__name__}, it must be either "
                    f"all strings or all numbers (int, float)."
                )
            if all_numbers:
                # Check for mixed int and float types
                contains_int = any(isinstance(d, int) for d in data)
                contains_float = any(isinstance(d, float) for d in data)
                if contains_int and contains_float:
                    data = [float(d) for d in data]
            if isinstance(data, tuple):
                data = list(data)

        # Websocket attributes
        self.variable_id = None
        self._initialized = False  # Flags variable is ready to send websocket updates
        self._pending_grad_fn_id = None  # Flags grad_fn is being set (fwd pass running)
        self._pending_grad = False  # Flags grad is being set (bwd pass running)
        self._pending_data = False  # Flags data is being set (optim step running)

        # Internal attributes
        self._data = data
        self.role = role
        self.requires_grad = requires_grad
        self._retain_grad = False
        self._grad = []
        self._output_nr = 0
        self._grad_fn = None
        self.is_leaf = not requires_grad or self.grad_fn is None

        # Share the variable with the websocket server
        if not is_variable_notify_suppressed():
            try:
                from afnio.cognitive.parameter import Parameter

                # Get the singleton websocket client
                _, ws_client = get_default_clients()

                payload = {
                    "data": self.data,
                    "role": self.role,
                    "requires_grad": self.requires_grad,
                    "obj_type": (
                        "__parameter__"
                        if isinstance(self, Parameter)
                        else "__variable__"
                    ),
                }
                response = run_in_background_loop(
                    ws_client.call("create_variable", payload)
                )
                if "error" in response:
                    raise RuntimeError(
                        response["error"]["data"].get("exception", response["error"])
                    )
                logger.debug(f"Variable created and shared with the server: {self!r}")

                variable_id = response.get("result", {}).get("variable_id")
                if not variable_id:
                    raise RuntimeError(
                        f"Server did not return a variable_id "
                        f"for payload: {payload!r}, response: {response!r}"
                    )
                self.variable_id = variable_id
                self._initialized = True

                register_variable(self)
            except Exception as e:
                logger.error(f"Failed to share Variable with the server: {e}")
                raise

    # TODO: pretty print data lists
    def __repr__(self):
        if self._grad_fn:
            return f"Variable(data={self.data}, role={self.role}, grad_fn={self._grad_fn.name()})"  # noqa: E501
        return f"Variable(data={self.data}, role={self.role}, requires_grad={self.requires_grad})"  # noqa: E501

    # TODO: pretty print data lists
    def __str__(self):
        # Helper function to truncate a string if it's longer than 40 characters
        def truncate_str(s):
            if isinstance(s, (int, float)):
                return str(s)
            if len(s) > 40:
                return f"{s[:20]}...{s[-20:]}"
            return s

        # Helper function to show the first and last three elements if it is long
        def format_list(data_list):
            if len(data_list) > 6:
                truncated = [
                    truncate_str(d) for d in (data_list[:3] + ["..."] + data_list[-3:])
                ]
                return f"[{', '.join(truncated)}]"
            return f"[{', '.join(truncate_str(d) for d in data_list)}]"

        if isinstance(self.data, list):
            data_repr = format_list(self.data)
        else:
            data_repr = truncate_str(self.data)

        if self._grad_fn:
            return f"variable({data_repr}, role={truncate_str(self.role)}, grad_fn={self._grad_fn.name()})"  # noqa: E501
        return f"variable({data_repr}, role={truncate_str(self.role)}, requires_grad={self.requires_grad})"  # noqa: E501

    def __add__(self, other) -> "Variable":
        if not isinstance(other, Variable):
            raise TypeError("Only Variables can be added to each other.")

        from afnio.autodiff.basic_ops import Add

        return Add.apply(self, other)

    def __iadd__(self, other) -> "Variable":
        if not isinstance(other, Variable):
            raise TypeError("Only Variables can be added to each other.")

        from afnio.autodiff.basic_ops import Add

        result = Add.apply(self, other)
        self.data = result.data
        self.role = result.role
        self.requires_grad = result.requires_grad
        # Update the grad function in case `other` also has `requires_grad`
        if result.requires_grad:
            with _allow_grad_fn_assignment():
                self.grad_fn = result.grad_fn
        return self
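
    # Usage sketch (illustrative, not part of the original source): `__add__` routes
    # through the autodiff `Add` function, so the result tracks gradients when either
    # operand requires them. The exact outcomes noted below are assumptions.
    #
    #   >>> a = Variable("Hello, ", requires_grad=True)
    #   >>> b = Variable("world!")
    #   >>> c = a + b          # dispatched to Add.apply(a, b)
    #   >>> c.requires_grad    # True, inherited from `a`
    #   >>> c.is_leaf          # False, `c` was produced by an operation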

    def backward(
        self, gradient=None, retain_graph=None, create_graph=False, inputs=None
    ) -> None:
        r"""Computes the gradient of current variable wrt graph leaves.

        The graph is differentiated using the chain rule. If the variable is
        non-scalar (i.e. its data has more than one element) and requires gradient,
        the function additionally requires specifying a ``gradient``. It should be a
        variable with data of matching type and shape, that represents the gradient
        of the differentiated function w.r.t. ``self``.

        This function accumulates gradients in the leaves - you might need to zero
        ``.grad`` attributes or set them to ``None`` before calling it.

        .. note::

            When ``inputs`` are provided, each input must be a leaf variable. If any
            input is not a leaf, a ``RuntimeError`` is raised.

        Args:
            gradient (Variable, optional): The gradient of the function being
                differentiated w.r.t. ``self``. This argument can be omitted if
                ``self`` is a scalar.
            retain_graph (bool, optional): If ``False``, the graph used to compute the
                grads will be freed. Setting this to ``True`` retains the graph,
                allowing for additional backward calls on the same graph, useful for
                example for multi-task learning where you have multiple losses.
                However, retaining the graph is not needed in nearly all cases and can
                be worked around in a much more efficient way. Defaults to the value
                of ``create_graph``.
            create_graph (bool, optional): If ``True``, graph of the derivative will
                be constructed, allowing to compute higher order derivative products.
                Defaults to ``False``.
            inputs (sequence of Variable, optional): Inputs w.r.t. which the gradient
                will be accumulated into ``.grad``. All other variables will be
                ignored. If not provided, the gradient is accumulated into all the
                leaf Variables that were used to compute the :attr:`variables`.
        """
        if self.is_leaf:
            raise RuntimeError(
                "Variable does not require grad or does not have a grad_fn."
            )
        afnio.autodiff.backward(
            self, gradient, retain_graph, create_graph, inputs=inputs
        )
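
    # Usage sketch (illustrative, assuming a graph built from tracked operations):
    # `backward()` is called on a non-leaf result and accumulates gradients into the
    # leaf Variables that produced it.
    #
    #   >>> x = Variable("abc", requires_grad=True)
    #   >>> loss = x + Variable("def")
    #   >>> loss.backward()    # gradients accumulate into x.grad
    #   >>> x.grad             # list of Variable gradients for the leaf `x`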

    def requires_grad_(self, mode: bool = True) -> "Variable":
        r"""
        requires_grad_(requires_grad=True) -> Variable

        Change if autodiff should record operations on this variable: sets this
        variable's :attr:`requires_grad` attribute in-place. Returns this variable.

        :func:`requires_grad_`'s main use case is to tell autodiff to begin recording
        operations on a Variable ``variable``. If ``variable`` has
        ``requires_grad=False`` (because it was obtained through a DataLoader, or
        required preprocessing or initialization), ``variable.requires_grad_()`` makes
        it so that autodiff will begin to record operations on ``variable``.

        Args:
            requires_grad (bool): If autodiff should record operations on this
                variable. Default: ``True``.

        Example:
            >>> # Initialize with requires_grad=False for data preprocessing
            >>> x = hf.Variable(data="abc", role="input")
            >>> x = preprocess(x)  # Preprocess without gradient tracking
            >>> x
            variable(abc, role=input, requires_grad=False)

            >>> # Now enable requires_grad for backpropagation
            >>> x.requires_grad_()
            >>> output = model(x)
            >>> output.backward()  # Backpropagation through `x`
            >>> x.grad
            variable(ABC, role=input, requires_grad=True)
        """
        self.requires_grad = mode
        self.is_leaf = not self.requires_grad or self.grad_fn is None
        return self

    @property
    def data(self):
        self._wait_for_pending(
            "_pending_data"
        )  # Wait until the pending flag is cleared
        return self._data

    @data.setter
    def data(self, value):
        self._data = value

    @property
    def output_nr(self) -> int:
        return self._output_nr

    @output_nr.setter
    def output_nr(self, n: int):
        if not isinstance(n, int) or not (n >= 0):
            raise TypeError(
                f"`output_nr` can only be an int greater or equal to 0, "
                f"but {n} is of type {type(n).__name__}"
            )
        self._output_nr = n

    @property
    def grad_fn(self) -> Optional["Node"]:
        self._wait_for_pending(
            "_pending_grad_fn_id"
        )  # Wait until the pending flag is cleared
        return self._grad_fn

    @grad_fn.setter
    def grad_fn(self, fn: Callable):
        """
        Sets the ``grad_fn`` that will be called by the engine to produce the actual
        gradient for this variable.
        """
        if not getattr(_grad_fn_assignment_allowed, "value", False):
            raise AttributeError(
                "Direct assignment to `grad_fn` is not allowed. "
                "Use Function.apply() to construct Variables with a grad_fn."
            )
        if not self.requires_grad:
            raise RuntimeError(
                "Cannot set `grad_fn` on a variable that does not require gradients. "
                "To enable gradient tracking for this variable, call "
                "`.requires_grad_()` before setting `grad_fn`. Only variables with "
                "`requires_grad=True` can have a gradient function (`grad_fn`)."
            )
        self._grad_fn = fn
        self.is_leaf = not self.requires_grad or self.grad_fn is None

    @property
    def grad(self) -> Optional["Variable"]:
        self._wait_for_pending(
            "_pending_grad"
        )  # Wait until the pending flag is cleared
        if self.is_leaf or self._retain_grad:
            return self._grad
        else:
            # Throwing a `UserWarning` instead of a `RuntimeError` could do here, like
            # in PyTorch, but for now I cannot think of any use case for not throwing
            # the error
            raise RuntimeError(
                "Attempted to access .grad for a non-leaf Variable without retain_grad "
                "enabled. Non-leaf Variables do not have their gradients retained by "
                "default in autodiff. To retain gradients for this Variable, call "
                "``.retain_grad()`` before performing the backward pass."
            )

    @grad.setter
    def grad(self, gradient: List["Variable"]):
        """
        Sets the ``.grad`` for this variable if it is a leaf or has ``.retain_grad``
        enabled.
        """
        if not isinstance(gradient, list) or not all(
            isinstance(g, Variable) for g in gradient
        ):
            raise TypeError(
                f"`.grad` expects a list of Variables for the gradient to accumulate, "
                f"but got {type(gradient).__name__}."
            )
        if self.is_leaf or self._retain_grad:
            self._grad = gradient
        else:
            # Throwing a `UserWarning` instead of a `RuntimeError` could do here, like
            # in PyTorch, but for now I cannot think of any use case for not throwing
            # the error
            raise RuntimeError(
                "Attempted to set .grad for a non-leaf Variable without retain_grad "
                "enabled. Non-leaf Variables do not have their gradients retained by "
                "default in autodiff. To retain gradients for this Variable, call "
                "``.retain_grad()`` before performing the backward pass."
            )

    def append_grad(self, gradient: "Variable"):
        """
        Appends a gradient value to the list ``.grad`` for this variable.
        """
        if self.is_leaf or self._retain_grad:
            self._on_append_grad(gradient)
            self._grad.append(gradient)
        else:
            # Throwing a `UserWarning` instead of a `RuntimeError` could do here, like
            # in PyTorch, but for now I cannot think of any use case for not throwing
            # the error
            raise RuntimeError(
                "Attempted to append to .grad for a non-leaf Variable without "
                "retain_grad enabled. Non-leaf Variables do not have their gradients "
                "retained by default in autodiff. To retain gradients for this "
                "Variable, call ``.retain_grad()`` before performing the backward pass."
            )

    def retain_grad(self):
        """Enable gradient retention for non-leaf variables."""
        if not self.is_leaf:
            self._retain_grad = True
        else:
            raise RuntimeError("Cannot call retain_grad on a leaf variable")
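
    # Usage sketch (illustrative): non-leaf Variables drop their gradients unless
    # `retain_grad()` is called before the backward pass; accessing `.grad` on a
    # non-leaf without it raises a RuntimeError.
    #
    #   >>> y = Variable("abc", requires_grad=True) + Variable("def")
    #   >>> y.retain_grad()    # allowed: y is non-leaf
    #   >>> y.backward()
    #   >>> y.grad             # now populated instead of raising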

    def detach(self) -> "Variable":
        """
        Returns a new Variable, detached from the computation graph.
        This new Variable will not have a `grad_fn` and will not track gradients.
        """
        return Variable(self.data, role=self.role, requires_grad=False)
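
    # Usage sketch (illustrative): `detach()` returns a fresh leaf Variable that
    # keeps the same data and role but is cut off from the computation graph.
    #
    #   >>> z = Variable("abc", requires_grad=True) + Variable("def")
    #   >>> frozen = z.detach()
    #   >>> frozen.requires_grad   # False
    #   >>> frozen.is_leaf         # True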

    # def clone(self):
    #     """
    #     Create a copy of this Variable, preserving the data.
    #     """
    #     return copy.deepcopy(self)

    def __deepcopy__(self, memo):
        if not self.is_leaf:
            raise RuntimeError(
                "Only Variables created explicitly by the user "
                "(graph leaves) support the deepcopy protocol at the moment."
            )
        if id(self) in memo:
            return memo[id(self)]
        with afnio.no_grad():
            new_variable = Variable(
                data=deepcopy(self.data, memo),
                role=self.role,
                requires_grad=self.requires_grad,
            )
            new_variable._retain_grad = self._retain_grad
            new_variable._output_nr = self._output_nr
            if self.grad_fn:
                with _allow_grad_fn_assignment():
                    new_variable.grad_fn = deepcopy(
                        self.grad_fn, memo
                    )  # Also sets `.is_leaf`
            if self.grad != []:
                new_variable.grad = deepcopy(self.grad, memo)
            new_variable.__dict__ = deepcopy(self.__dict__, memo)
            memo[id(self)] = new_variable
            return new_variable
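
    # Usage sketch (illustrative): only leaf Variables support deepcopy; the copy is
    # created under `afnio.no_grad()` and then mirrors the original's attributes.
    #
    #   >>> import copy
    #   >>> p = Variable(["a", "b"], role="prompt", requires_grad=True)
    #   >>> q = copy.deepcopy(p)   # independent leaf with the same data and role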

    def copy_(self, src: "Variable") -> "Variable":
        """
        Copies the data from the source Variable into this Variable.

        Args:
            src (Variable): The source Variable to copy from.

        Returns:
            self: The current Variable with updated data, role and requires_grad.

        Raises:
            TypeError: If the source is not a Variable.
            ValueError: If the source data type does not match the target data type.
        """
        if not is_variable(src):
            raise TypeError(
                f"Expected `src` to be a Variable, but got {type(src).__name__}."
            )

        is_scalar_self = is_scalar_variable(self)
        is_scalar_src = is_scalar_variable(src)

        if is_scalar_self and is_scalar_src:
            self.data = src.data
        elif not is_scalar_self and not is_scalar_src:
            if len(self.data) != len(src.data):
                raise ValueError(
                    f"Cannot copy list `.data` fields of different lengths: "
                    f"{len(self.data)} vs {len(src.data)}."
                )
            self.data = src.data.copy()
        else:
            raise ValueError(
                f"Cannot copy data from {type(src.data).__name__} "
                f"to {type(self.data).__name__}."
            )

        self.role = src.role
        self.requires_grad = src.requires_grad
        return self
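
    # Usage sketch (illustrative): `copy_` is an in-place update that requires the
    # source and target `.data` to be of the same kind (both scalar, or both lists of
    # equal length).
    #
    #   >>> dst = Variable(["a", "b"], role="old")
    #   >>> src = Variable(["x", "y"], role="new")
    #   >>> dst.copy_(src)         # dst.data == ["x", "y"], dst.role == "new"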

    def is_floating_point(self) -> bool:
        """
        Checks if the Variable's data contains floating-point values.

        Returns:
            bool: True if the data is a floating-point type (either scalar or all
                elements in a list/tuple are floating-point).
        """
        if isinstance(self.data, float):
            return True
        if isinstance(self.data, (list, tuple)) and self.data:
            return all(isinstance(d, float) for d in self.data)
        return False
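
    # Usage sketch (illustrative): only float scalars, or lists whose elements are
    # all floats, count as floating point. Note that mixed int/float lists are
    # promoted to all-float in `__init__`.
    #
    #   >>> Variable(1.5).is_floating_point()         # True
    #   >>> Variable([1.0, 2.0]).is_floating_point()  # True
    #   >>> Variable([1.0, 2]).is_floating_point()    # True: promoted to [1.0, 2.0]
    #   >>> Variable("1.5").is_floating_point()       # False: data is a string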

    def to(self, dtype=None) -> "Variable":
        """
        Cast the data of the Variable to the specified dtype.

        Args:
            dtype (Optional[type]): The target type to cast the data
                (e.g., float, int, str).

        Returns:
            Variable: A new Variable with data cast to the target dtype.
        """
        if dtype is not None:
            if not is_scalar_variable(self):
                # Cast each element in the list to the target dtype
                new_data = [dtype(d) for d in self.data]
            else:
                # Cast scalar data to the target dtype
                new_data = dtype(self.data)
        else:
            # No dtype casting
            new_data = self.data

        # Return a new Variable with the same role and requires_grad, but updated data
        return Variable(data=new_data, role=self.role, requires_grad=self.requires_grad)
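
    # Usage sketch (illustrative): `to()` returns a new Variable with `.data` cast
    # element-wise (or as a scalar) to the requested Python type.
    #
    #   >>> v = Variable(["1", "2", "3"], role="scores")
    #   >>> v.to(int).data             # [1, 2, 3]
    #   >>> Variable(3).to(str).data   # "3"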

    def _on_variable_change(self, field: str, value):
        """
        Notify the server of a change in the variable's attributes.

        This method is called whenever an attribute of the variable is set.
        It sends a notification to the server with the updated field and value.

        Args:
            field (str): The name of the field that changed.
            value: The new value of the field.

        Raises:
            RuntimeError: If the variable is not registered with the server or
                if the server response does not match the request.
            TypeError: If the provided value is of an unexpected type for the field.
        """
        from afnio._utils import _serialize_arg

        if is_variable_notify_suppressed():
            return  # Do not notify server

        if self.variable_id is None:
            logger.error(
                f"Cannot notify server: "
                f"variable_id=None, field='{field}', value={value!r}"
            )
            raise RuntimeError("Cannot notify server: variable_id is None.")

        if field in {
            "output_nr",
            "grad_fn",
            "grad",
            "_initialized",
            "_pending_grad_fn_id",
            "_pending_grad",
            "_pending_data",
            "__dict__",  # Avoids server error when calling `Optimizer.load_state_dict`
        }:
            # Do not notify for the property setter, as we already notify
            # for all the changes made inside the property setter.
            # Also do not notify for `_initialized` and pending states
            return
        elif field == "_data":
            field = "data"  # `data` is a property only on the client
            end_value = value
        elif field == "_grad":
            if not isinstance(value, list):
                raise TypeError(
                    f"Expected `value` to be a list for field '{field}', "
                    f"but got {type(value).__name__}."
                )
            end_value = [_serialize_arg(g) for g in value]
        elif field == "_grad_fn":
            # Only allow notification if inside the `__iadd__` method
            if not _called_directly_from_iadd():
                raise RuntimeError(
                    "Setting `grad_fn` is only allowed on the server by the autodiff "
                    "engine. Do not use `_allow_grad_fn_assignment()` on the client."
                )
            end_value = value.node_id  # Use only the node ID for notification
        else:
            end_value = value

        payload = {
            "variable_id": self.variable_id,
            "field": field,
            "value": end_value,
        }

        try:
            _, ws_client = get_default_clients()
            response = run_in_background_loop(
                ws_client.call("update_variable", payload)
            )
            if "error" in response:
                raise RuntimeError(
                    response["error"]["data"].get("exception", response["error"])
                )

            # Check server response
            if (
                response["result"]["variable_id"] != self.variable_id
                or response["result"]["field"] != field
                or response["result"]["value"] != end_value
            ):
                raise RuntimeError(
                    f"Server response mismatch: (received {response['result']!r}, "
                    f"but expected variable_id={self.variable_id!r}, field={field!r}, "
                    f"value={end_value!r})"
                )

            logger.debug(
                f"Variable change notified to server and confirmed: "
                f"variable_id={self.variable_id!r}, field='{field}', "
                f"value={end_value!r}"
            )
        except Exception as e:
            logger.exception(f"Failed to notify server of variable change: {e}")
            raise

    def _on_append_grad(self, gradient: "Variable"):
        """
        Notify the server that a new gradient has been appended to this variable.

        This method is called before the gradient is added to the local `.grad` list.
        It sends an 'append_grad' RPC request to the server, including the variable's
        ID and the serialized gradient. The method blocks until the server
        acknowledges the append operation, ensuring synchronization between client
        and server.

        Args:
            gradient (Variable): The gradient variable to append.

        Raises:
            RuntimeError: If the variable is not registered with the server or
                if the server response does not match the request.
            TypeError: If the provided gradient is not a Variable.
        """
        from afnio._utils import _serialize_arg

        if is_variable_notify_suppressed():
            return  # Do not notify server

        if self.variable_id is None:
            logger.error(
                f"Cannot notify server: variable_id=None, gradient={gradient!r}"
            )
            raise RuntimeError("Cannot notify server: variable_id is None.")

        if not isinstance(gradient, Variable):
            raise TypeError(
                f"Expected `gradient` to be a Variable, "
                f"but got {type(gradient).__name__}."
            )

        ser_grad = _serialize_arg(gradient)
        payload = {
            "variable_id": self.variable_id,
            "gradient": ser_grad,
        }

        try:
            _, ws_client = get_default_clients()
            response = run_in_background_loop(ws_client.call("append_grad", payload))
            if "error" in response:
                raise RuntimeError(
                    response["error"]["data"].get("exception", response["error"])
                )

            # Check server response
            if (
                response["result"]["variable_id"] != self.variable_id
                or response["result"]["gradient_id"] != gradient.variable_id
            ):
                raise RuntimeError(
                    f"Server response mismatch: (received {response['result']!r}, "
                    f"but expected variable_id={self.variable_id!r}, "
                    f"gradient={ser_grad!r})"
                )

            logger.debug(
                f"Gradient append notified to server and confirmed: "
                f"variable_id={self.variable_id!r}, gradient={ser_grad!r}"
            )
        except Exception as e:
            logger.exception(f"Failed to notify server of gradient append: {e}")
            raise

    def __setattr__(self, name, value):
        super().__setattr__(name, value)
        if getattr(self, "_initialized", False):
            self._on_variable_change(name, value)
        # TODO: Should we handle the else condition and throw an error?

    def _wait_for_pending(
        self, attr_name: str, timeout: float = 3, interval: float = 0.01
    ) -> None:
        """
        Wait until the attribute specified by `attr_name` is no longer truthy.

        Uses time.monotonic() for more reliable timeout measurement.

        Args:
            attr_name (str): Name of the attribute to wait on.
            timeout (float): Maximum time to wait, in seconds.
            interval (float): How frequently to check the attribute, in seconds.

        Raises:
            RuntimeError: If the attribute remains truthy after the timeout.
        """
        end_time = time.monotonic() + timeout
        while getattr(self, attr_name):
            if time.monotonic() > end_time:
                raise RuntimeError(
                    f"Timeout waiting for {attr_name} to be cleared "
                    f"for variable_id={self.variable_id}"
                )
            time.sleep(interval)


def is_variable(obj):
    r"""Returns True if `obj` is an Afnio variable.

    Note that this function is simply doing ``isinstance(obj, hf.Variable)``.
    Using that ``isinstance`` check is better for typechecking with mypy,
    and more explicit - so it's recommended to use that instead of
    ``is_variable``. Use ``is_variable`` for example when importing
    ``Variable`` creates circular dependencies.

    Args:
        obj (Object): Object to test

    Example::

        >>> x = hf.Variable("abc")
        >>> hf.is_variable(x)
        True
    """
    return isinstance(obj, Variable)


def is_scalar_variable(obj):
    """
    Check if an object is a Variable and its `.data` is a scalar
    (of type str, int, or float).

    Args:
        obj: The object to check.

    Returns:
        bool: True if the object is a scalar Variable, False otherwise.
    """
    if not is_variable(obj):
        return False
    data = getattr(obj, "data", None)
    return isinstance(data, (str, int, float))


@contextmanager
def _allow_grad_fn_assignment():
    """
    Context manager that allows assignment to the `grad_fn` attribute of Variables.

    This is useful for internal operations where you need to set the `grad_fn`
    directly, bypassing the usual restrictions.

    .. note::
        This context manager should only be used by the autodiff engine, as it allows
        direct manipulation of the `grad_fn` attribute, which is typically managed
        internally. Manual use is strongly discouraged.
    """
    previous_state = getattr(_grad_fn_assignment_allowed, "value", False)
    _grad_fn_assignment_allowed.value = True  # Allow grad_fn assignment
    try:
        yield  # Execute the block
    finally:
        _grad_fn_assignment_allowed.value = previous_state  # Restore original state


def _called_directly_from_iadd():
    """
    Check if the current function call stack indicates that we are being called
    directly from the `Variable.__iadd__` method.
    """
    stack = inspect.stack()
    # Look for the frame corresponding to __iadd__
    for frame in stack:
        if frame.function == "__iadd__":
            # Check filename
            if frame.filename.endswith("_variable.py"):
                # Check if 'self' is in locals and is a Variable
                self_obj = frame.frame.f_locals.get("self")
                if self_obj is not None and type(self_obj).__name__ == "Variable":
                    return True
    return False
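

# Usage sketch (illustrative, not part of the original source): `is_scalar_variable`
# distinguishes scalar-valued Variables from list-valued ones, which `copy_` and
# `to` rely on above.
#
#   >>> is_scalar_variable(Variable("abc"))        # True
#   >>> is_scalar_variable(Variable(["a", "b"]))   # False
#   >>> is_scalar_variable("abc")                  # False: not a Variable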