# afnio/autodiff/__init__.py
import logging
from typing import Optional
from afnio._utils import _serialize_arg
from afnio.logging_config import configure_logging
from afnio.tellurio._client_manager import get_default_clients
from afnio.tellurio._eventloop import run_in_background_loop
from afnio.tellurio._variable_registry import get_variable
from .grad_mode import is_grad_enabled, no_grad, set_grad_enabled
from .utils import (
_VariableOrVariables,
_VariableOrVariablesOrGradEdge,
)
__all__ = [
"backward",
"is_grad_enabled",
"no_grad",
"set_grad_enabled",
]
# Please keep this list sorted
assert __all__ == sorted(__all__)
# Configure logging
configure_logging()
logger = logging.getLogger(__name__)
def backward(
variables: _VariableOrVariables,
grad_variables: Optional[_VariableOrVariables] = None,
retain_graph: Optional[bool] = None,
create_graph: bool = False,
inputs: Optional[_VariableOrVariablesOrGradEdge] = None,
) -> None:
r"""Computes the sum of gradients of given variables with respect to graph
leaves.
The graph is differentiated using the chain rule. If any of ``variables``
are non-scalar (i.e. their data has more than one element) and require
gradient, then the Jacobian-vector product would be computed, in this
case the function additionally requires specifying ``grad_variables``.
It should be a sequence of matching length, that contains the "vector"
in the Jacobian-vector product, usually the gradient of the differentiated
function w.r.t. corresponding variables (``None`` is an acceptable value for
all variables that don't need gradient variables).
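
    For instance (a minimal, hypothetical sketch; ``out`` and ``grad_out`` stand
    in for a non-scalar output ``Variable`` and its matching "vector", both built
    elsewhere with afnio operations)::

        backward(out, grad_variables=grad_out)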

    This function accumulates gradients in the leaves; you might need to zero
    ``.grad`` attributes or set them to ``None`` before calling it.

    .. note::
        Using this method with ``create_graph=True`` will create a reference cycle
        between the parameter and its gradient, which can cause a memory leak.
        We recommend using ``autodiff.grad`` when creating the graph to avoid this.
        If you have to use this function, make sure to reset the ``.grad`` fields of
        your parameters to ``None`` after use to break the cycle and avoid the leak.
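
        A minimal sketch of breaking the cycle after such a call (``loss`` and
        ``params`` are hypothetical, standing for an output ``Variable`` and an
        iterable of leaf ``Variable`` parameters)::

            backward(loss, create_graph=True)
            # ... use the higher-order graph ...
            for p in params:
                p.grad = None  # break the reference cycle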

    .. note::
        When ``inputs`` are provided, each input must be a leaf variable. If any
        input is not a leaf, a ``RuntimeError`` is raised.

    Args:
        variables (Sequence[Variable] or Variable): Variables of which the derivative
            will be computed.
        grad_variables (Sequence[Variable or None] or Variable, optional): The "vector"
            in the Jacobian-vector product, usually gradients w.r.t. each element of
            the corresponding variables. None values can be specified for scalar
            Variables or ones that don't require grad. If a None value would be
            acceptable for all grad_variables, then this argument is optional.
        retain_graph (bool, optional): If ``False``, the graph used to compute the
            grads will be freed. Setting this to ``True`` retains the graph, allowing
            additional backward calls on the same graph, which is useful, for example,
            in multi-task learning where you have multiple losses. However, retaining
            the graph is rarely needed and can usually be worked around in a much more
            efficient way. Defaults to the value of ``create_graph``.
        create_graph (bool, optional): If ``True``, the graph of the derivative will
            be constructed, allowing higher-order derivative products to be computed.
            Defaults to ``False``.
        inputs (Sequence[Variable] or Variable or Sequence[GradientEdge], optional):
            Inputs w.r.t. which the gradient will be accumulated into ``.grad``. All
            other Variables will be ignored. If not provided, the gradient is
            accumulated into all the leaf Variables that were used to compute
            the :attr:`variables`.
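
    Example:
        A minimal, hypothetical usage sketch; ``loss`` and ``x`` stand in for a
        scalar output ``Variable`` and a leaf ``Variable`` built elsewhere with
        afnio operations (the names are illustrative only)::

            from afnio.autodiff import backward

            # Each call below is an independent example, not a sequence:
            backward(loss)                     # accumulate grads into all leaves
            backward(loss, inputs=[x])         # accumulate only into the leaf ``x``
            backward(loss, retain_graph=True)  # keep the graph for further calls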
"""
# Serialize the arguments
serialized_variables = _serialize_arg(variables)
serialized_grad_variables = _serialize_arg(grad_variables)
serialized_retain_graph = _serialize_arg(retain_graph)
serialized_create_graph = _serialize_arg(create_graph)
serialized_inputs = _serialize_arg(inputs)
    # Send the RPC calls to the server; the ID list is initialized up front so
    # the except handler below can always reference it
    backprop_variable_ids = []
try:
# Get the singleton websocket client
_, ws_client = get_default_clients()
        # Fetch all Variables whose gradients will be computed during backpropagation
        # and mark them as pending for grad update
payload = {"variables": serialized_variables}
response_ids = run_in_background_loop(
ws_client.call("get_backprop_ids", payload)
)
if "error" in response_ids:
raise RuntimeError(
response_ids["error"]["data"].get("exception", response_ids["error"])
)
logger.debug(
f"Fetched backpropagation variable IDs from the server: {response_ids!r}"
)
result_ids = response_ids.get("result", {})
backprop_variable_ids = result_ids.get("variable_ids", [])
if backprop_variable_ids:
for var_id in backprop_variable_ids:
var = get_variable(var_id)
if var is not None:
var._pending_grad = True
logger.debug(
f"Marked variable {var_id!r} as pending for grad update."
)
else:
logger.warning(
f"Variable id {var_id!r} returned for backward, "
"but not found in VARIABLE_REGISTRY."
)
# Run backward pass
payload = {
"variables": serialized_variables,
"grad_variables": serialized_grad_variables,
"retain_graph": serialized_retain_graph,
"create_graph": serialized_create_graph,
"inputs": serialized_inputs,
}
response_bwd = run_in_background_loop(ws_client.call("run_backward", payload))
if "error" in response_bwd:
raise RuntimeError(
response_bwd["error"]["data"].get("exception", response_bwd["error"])
)
        logger.debug(
            f"Backward pass request sent to the server: {variables!r}"
        )
        result_message = response_bwd.get("result", {}).get("message")
        if result_message != "Backward pass executed successfully.":
            raise RuntimeError(
                f"Server did not confirm successful execution of the backward pass: "
                f"payload={payload!r}, response={response_bwd!r}"
            )
logger.debug(
f"Backward pass executed successfully with variables: {variables!r}"
)
except Exception as e:
        logger.error(f"Failed to execute backward pass on the server: {e}")
# Clear all pending grad flags to avoid deadlocks
for var_id in backprop_variable_ids:
var = get_variable(var_id)
if var is not None:
var._pending_grad = False
                logger.debug(
                    f"Cleared pending grad flag for variable {var_id!r} "
                    f"after error."
                )
raise