# Source code for afnio.autodiff.grad_mode

import threading
from contextlib import contextmanager

# Thread-local flag to control gradient tracking. Because this is a
# ``threading.local`` object, each thread sees its own ``enabled`` attribute.
_grad_enabled = threading.local()
# NOTE: this assignment runs once, in the thread that imports the module, so
# only that thread gets the attribute here. Other threads start without it;
# is_grad_enabled() treats a missing attribute as True (tracking on).
_grad_enabled.enabled = True  # By default, gradients are enabled


def is_grad_enabled() -> bool:
    """Report whether gradient tracking is active in the calling thread.

    The flag is stored per thread. A thread that has never had the flag set
    reports ``True``, because tracking is on by default.
    """
    try:
        return _grad_enabled.enabled
    except AttributeError:
        # This thread never stored a flag: fall back to the default (enabled).
        return True
def set_grad_enabled(mode: bool) -> None:
    """Enable or disable gradient tracking for the calling thread.

    The flag lives in thread-local storage, so this changes the setting for
    the current thread only; other threads keep their own value.

    Args:
        mode: ``True`` to enable gradient tracking, ``False`` to disable it.
            Any value is accepted and interpreted by its truthiness. It is
            stored as a real ``bool``, so that :func:`is_grad_enabled` reports
            a ``bool`` rather than whatever object was passed in.
    """
    # Normalize so the stored flag always honors is_grad_enabled's ``-> bool``
    # annotation (e.g. ``None``/``0`` become ``False``); truthiness is unchanged.
    _grad_enabled.enabled = bool(mode)
@contextmanager
def _no_grad_context():
    """Run the enclosed code with gradient tracking disabled.

    Private generator behind :func:`no_grad`. It records the calling thread's
    current flag, turns tracking off, and restores the recorded value on exit
    -- including when the body raises -- so an exception cannot leave
    tracking disabled. When used as a decorator, ``contextlib`` builds a fresh
    generator for every call, so nested and recursive calls each restore
    their own saved state.
    """
    previous_state = is_grad_enabled()  # Store the current state
    set_grad_enabled(False)  # Disable gradients
    try:
        yield  # Execute the block
    finally:
        set_grad_enabled(previous_state)  # Restore the original state


def no_grad(func=None):
    """Context manager that disables gradient calculation.

    All operations inside the block run without tracking gradients, which
    makes them more memory-efficient. Disabling gradient calculation is
    useful for inference, when you are sure that you will not call
    :meth:`Variable.backward()`. It reduces memory consumption for
    computations that would otherwise have ``requires_grad=True``.

    In this mode, the result of every computation will have
    ``requires_grad=False``, even when the inputs have ``requires_grad=True``.
    There is one exception: factory functions, i.e. functions that create a
    new Variable and take a ``requires_grad`` kwarg, are NOT affected by this
    mode.

    This context manager is thread local; it will not affect computation in
    other threads. The previous gradient-tracking state is restored on exit,
    including when the block raises.

    It also works as a decorator, both called (``@no_grad()``) and bare
    (``@no_grad``); the decorated function runs with gradients disabled on
    every call.

    Args:
        func: Optional callable to decorate. Leave it out for the
            ``with no_grad():`` and ``@no_grad()`` forms; it is supplied
            implicitly by the bare ``@no_grad`` form.

    Returns:
        A context manager that can also be used as a decorator or, when
        ``func`` is given, ``func`` wrapped so that each call runs under
        ``no_grad``.

    Raises:
        TypeError: If ``func`` is given but is not callable.

    Example::

        >>> x = hf.Variable("abc", role="variable", requires_grad=True)
        >>> with hf.no_grad():
        ...     y = x + x
        >>> y.requires_grad
        False
        >>> @hf.no_grad()
        ... def doubler(x):
        ...     return x + x
        >>> z = doubler(x)
        >>> z.requires_grad
        False
        >>> @hf.no_grad
        ... def tripler(x):
        ...     return x + x + x
        >>> z = tripler(x)
        >>> z.requires_grad
        False
        >>> # factory function exception
        >>> with hf.no_grad():
        ...     a = hf.cognitive.Parameter("xyz")
        >>> a.requires_grad
        True
    """
    if func is None:
        # ``with no_grad():`` and ``@no_grad()`` both land here.
        return _no_grad_context()
    if not callable(func):
        raise TypeError(
            "no_grad() takes no arguments when used as a context manager; "
            f"as a bare decorator it expects a callable, got {type(func).__name__!r}"
        )
    # Bare ``@no_grad``: the generator-based context manager is a
    # ContextDecorator, so calling it with ``func`` returns a wrapper that
    # enters a fresh no-grad context on every invocation.
    return _no_grad_context()(func)