
tensorplay.autograd.grad_mode

Classes

class enable_grad [source]

```python
enable_grad(orig_func=None)
```

Bases: _NoParamDecoratorContextManager

Context-manager that enables gradient calculation.

Enables gradient calculation if it has been disabled via no_grad or set_grad_enabled.

This context manager is thread local; it will not affect computation in other threads.

Also functions as a decorator.
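The save-and-restore behaviour that lets enable_grad override an enclosing no_grad, together with decorator use with or without parentheses, can be modelled in plain Python. The sketch below is hypothetical: `_state`, `grad_enabled`, `_GradModeCM`, `NoGrad` and `EnableGrad` are illustrative stand-ins that only track a per-thread flag, not tensorplay's implementation.

```python
import functools
import threading

# Hypothetical per-thread grad-mode flag (not tensorplay's real state).
_state = threading.local()


def grad_enabled() -> bool:
    # A thread that never touched the flag sees grad mode as enabled.
    return getattr(_state, "enabled", True)


class _GradModeCM:
    """Toy context manager and decorator that forces grad mode to `target`."""

    target = True

    def __new__(cls, orig_func=None):
        # Bare-decorator form (@EnableGrad): the function arrives as orig_func.
        if callable(orig_func):
            return cls()(orig_func)
        return super().__new__(cls)

    def __enter__(self):
        self.prev = grad_enabled()   # remember the enclosing mode
        _state.enabled = self.target
        return self

    def __exit__(self, exc_type, exc, tb):
        _state.enabled = self.prev   # restore it on the way out
        return False

    def __call__(self, func):
        # Decorator form (@EnableGrad()): wrap func in a fresh context per call.
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with type(self)():
                return func(*args, **kwargs)
        return wrapper


class NoGrad(_GradModeCM):
    target = False


class EnableGrad(_GradModeCM):
    target = True


with NoGrad():
    with EnableGrad():
        inner = grad_enabled()       # enable_grad wins inside no_grad
    after_inner = grad_enabled()     # back to the enclosing no_grad


@EnableGrad                          # no parentheses
def probe_on():
    return grad_enabled()


@NoGrad()                            # with parentheses
def probe_off():
    return grad_enabled()


with NoGrad():
    via_bare = probe_on()
via_call = probe_off()
```

Each decorated call enters a fresh instance, so the remembered previous mode is never shared between nested or concurrent calls.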

INFO

enable_grad is one of several mechanisms that can enable or disable gradients locally; see locally-disable-grad-doc for more information on how they compare.

INFO

This API does not apply to forward-mode AD.

Example

```python
>>> # xdoctest: +SKIP
>>> import tensorplay
>>> x = tensorplay.tensor([1.], requires_grad=True)
>>> with tensorplay.no_grad():
...     with tensorplay.enable_grad():
...         y = x * 2
>>> y.requires_grad
True
>>> y.backward()
>>> x.grad
tensor([2.])
>>> @tensorplay.enable_grad()
... def doubler(x):
...     return x * 2
>>> with tensorplay.no_grad():
...     z = doubler(x)
>>> z.requires_grad
True
>>> @tensorplay.enable_grad  # also usable without parentheses
... def tripler(x):
...     return x * 3
>>> with tensorplay.no_grad():
...     z = tripler(x)
>>> z.requires_grad
True
```

Methods

clone(self) [source]


class inference_mode [source]

```python
inference_mode(mode=True)
```

Bases: _DecoratorContextManager

Context manager that enables or disables inference mode.

inference_mode is analogous to no_grad and should be used when you are certain your operations will not interact with autograd (e.g., during data loading or model evaluation). Compared to no_grad, it removes additional overhead by disabling view tracking and version counter bumps. It is also more restrictive: tensors created in this mode cannot be used in computations recorded by autograd.

This context manager is thread-local; it does not affect computation in other threads.

Also functions as a decorator.

INFO

Inference mode is one of several mechanisms that can locally enable or disable gradients. See locally-disable-grad-doc for a comparison. If it is difficult to keep tensors created in inference mode out of autograd-tracked regions, consider benchmarking your code with and without inference mode to weigh the performance benefits against the trade-offs. You can always use no_grad instead.
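Following the benchmarking advice above amounts to timing the same workload once under each mechanism. The helper below is a hypothetical sketch (`time_under`, `workload` and the candidate names are not part of tensorplay); it reports the best of several runs to damp scheduling noise.

```python
import contextlib
import time


def time_under(ctx_factory, fn, repeats=5):
    """Best-of-`repeats` wall-clock time in seconds for fn() run inside ctx_factory()."""
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        with ctx_factory():            # fresh context each run; enter/exit cost included
            fn()
        best = min(best, time.perf_counter() - start)
    return best


def workload():
    # Stand-in for a real model call.
    return sum(i * i for i in range(10_000))


# contextlib.nullcontext stands in for both modes so the sketch runs anywhere;
# swap in tensorplay.inference_mode and tensorplay.no_grad for a real comparison.
candidates = {"mode_a": contextlib.nullcontext, "mode_b": contextlib.nullcontext}
timings = {name: time_under(factory, workload) for name, factory in candidates.items()}
```

Passing the context manager class rather than an instance lets the helper create a fresh context for every run.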

INFO

Unlike some other mechanisms that locally enable or disable grad, entering inference_mode also disables forward-mode AD.

Args

  • mode (bool or function): Either a boolean flag to enable or disable inference mode, or a Python function to decorate with inference mode enabled.
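Accepting either a bool or a function through the single `mode` argument is a dispatch on whether `mode` is callable. The sketch below models that idea in plain Python, assuming a callable `mode` means "decorate this function with inference mode enabled"; `inference_mode_sketch`, `in_inference_mode` and `_local` are hypothetical names, and the real class also changes tensor behaviour, which is not modelled here.

```python
import contextlib
import functools
import threading

# Hypothetical per-thread inference flag, standing in for the real mode.
_local = threading.local()


def in_inference_mode() -> bool:
    return getattr(_local, "inference", False)


@contextlib.contextmanager
def _inference_ctx(enabled):
    prev = in_inference_mode()
    _local.inference = enabled
    try:
        yield
    finally:
        _local.inference = prev     # restore even if the body raises


def inference_mode_sketch(mode=True):
    """`mode` is either a bool flag or a function to decorate with the mode on."""
    if callable(mode):
        func = mode

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with _inference_ctx(True):
                return func(*args, **kwargs)
        return wrapper
    # Bool path: the returned context manager also works as a decorator,
    # because contextlib.contextmanager objects support ContextDecorator use.
    return _inference_ctx(bool(mode))


@inference_mode_sketch               # the function itself is passed as `mode`
def a():
    return in_inference_mode()


@inference_mode_sketch()             # mode=True, used as a decorator
def b():
    return in_inference_mode()


with inference_mode_sketch(False):   # explicit False leaves the mode off
    c = in_inference_mode()
```

The bool form is handy for conditional use, e.g. passing an `is_eval` flag instead of branching around a `with` block.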

Example

```python
>>> import tensorplay
>>> x = tensorplay.ones(1, 2, 3, requires_grad=True)
>>> with tensorplay.inference_mode():
...     y = x * x
>>> y.requires_grad
False
>>> y._version
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: Inference tensors do not track version counter.
>>> @tensorplay.inference_mode()
... def func(x):
...     return x * x
>>> out = func(x)
>>> out.requires_grad
False
>>> @tensorplay.inference_mode  # the function itself is passed as `mode`
... def doubler(x):
...     return x * 2
>>> out = doubler(x)
>>> out.requires_grad
False
```

Methods

__init__(self, mode: bool = True) -> None [source]

Initialize self. See help(type(self)) for accurate signature.


clone(self) -> 'inference_mode' [source]

Create a copy of this context manager.


class no_grad [source]

```python
no_grad() -> None
```

Bases: _NoParamDecoratorContextManager

Context-manager that disables gradient calculation.

Disabling gradient calculation is useful for inference, when you are sure that you will not call Tensor.backward(). It will reduce memory consumption for computations that would otherwise have requires_grad=True.

In this mode, the result of every computation will have requires_grad=False, even when the inputs have requires_grad=True. There is one exception: factory functions, that is, functions that create a new Tensor and take a requires_grad keyword argument, are not affected by this mode.
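The rule described above can be stated compactly: an operation's result requires grad only if grad mode is on and at least one input requires grad, whereas a factory function simply uses the requires_grad value it is given. The `ToyTensor` class and `toy_no_grad` context manager below are a hypothetical model of that rule, not tensorplay code.

```python
import contextlib
import threading

# Hypothetical per-thread grad-mode flag.
_mode = threading.local()


def grad_enabled() -> bool:
    return getattr(_mode, "on", True)


@contextlib.contextmanager
def toy_no_grad():
    prev = grad_enabled()
    _mode.on = False
    try:
        yield
    finally:
        _mode.on = prev


class ToyTensor:
    def __init__(self, value, requires_grad=False):
        # Factory path: the flag is taken as given, whatever the grad mode.
        self.value = value
        self.requires_grad = requires_grad

    def __mul__(self, other):
        # Operation path: track only if grad mode is on and an input requires grad.
        other_rg = isinstance(other, ToyTensor) and other.requires_grad
        other_val = other.value if isinstance(other, ToyTensor) else other
        out = ToyTensor(self.value * other_val)
        out.requires_grad = grad_enabled() and (self.requires_grad or other_rg)
        return out


x = ToyTensor(1.0, requires_grad=True)
with toy_no_grad():
    y = x * 2                                # operation: not tracked
    p = ToyTensor(0.5, requires_grad=True)   # factory: flag honoured
z = x * 2                                    # grad mode back on: tracked
```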

This context manager is thread local; it will not affect computation in other threads.
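Thread locality means a grad-mode flag set in one thread is invisible to the others. The sketch below demonstrates the property with Python's threading.local; using threading.local here is an assumption made for illustration, and `flag`, `current_mode` and `read_in_worker` are hypothetical names.

```python
import threading

# Hypothetical per-thread flag; each thread gets its own attribute namespace.
flag = threading.local()


def current_mode() -> bool:
    return getattr(flag, "grad_enabled", True)   # default: enabled


seen_in_worker = []


def read_in_worker():
    seen_in_worker.append(current_mode())


flag.grad_enabled = False          # like entering no_grad in the main thread
worker = threading.Thread(target=read_in_worker)
worker.start()
worker.join()
main_mode = current_mode()
flag.grad_enabled = True           # restore, like leaving the block
```

A consequence is that a no_grad block in the main thread does not cover work handed to other threads; those threads would need to enter the context manager themselves.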

Also functions as a decorator.

INFO

No-grad is one of several mechanisms that can enable or disable gradients locally; see locally-disable-grad-doc for more information on how they compare.

INFO

This API does not apply to forward-mode AD. If you want to disable forward AD for a computation, you can unpack your dual tensors.

Example

```python
>>> import tensorplay
>>> x = tensorplay.tensor([1.], requires_grad=True)
>>> with tensorplay.no_grad():
...     y = x * 2
>>> y.requires_grad
False
>>> @tensorplay.no_grad()
... def doubler(x):
...     return x * 2
>>> z = doubler(x)
>>> z.requires_grad
False
>>> @tensorplay.no_grad  # also usable without parentheses
... def tripler(x):
...     return x * 3
>>> z = tripler(x)
>>> z.requires_grad
False
>>> # factory function exception
>>> with tensorplay.no_grad():
...     a = tensorplay.nn.Parameter(tensorplay.rand(10))
>>> a.requires_grad
True
```

Methods

__init__(self) -> None [source]

Initialize self. See help(type(self)) for accurate signature.


clone(self) [source]


class set_grad_enabled [source]

```python
set_grad_enabled(mode: bool) -> None
```

Bases: _DecoratorContextManager

Context-manager that sets gradient calculation on or off.

set_grad_enabled enables or disables gradients according to its argument mode. It can be used as a context manager or as a plain function.
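One way to support both the function form and the context-manager form is to apply the new mode in the constructor and restore the previous mode in __exit__. The sketch below models that design; `SetGradEnabled`, `_tls` and `is_grad_on` are hypothetical names, not tensorplay's implementation.

```python
import threading

# Hypothetical per-thread grad-mode flag.
_tls = threading.local()


def is_grad_on() -> bool:
    return getattr(_tls, "enabled", True)


class SetGradEnabled:
    def __init__(self, mode: bool) -> None:
        self.prev = is_grad_on()
        self.mode = mode
        _tls.enabled = mode          # applied immediately: works as a plain call

    def __enter__(self):
        return self                  # mode was already set in __init__

    def __exit__(self, exc_type, exc, tb):
        _tls.enabled = self.prev     # with-statement form restores the old mode
        return False


with SetGradEnabled(False):
    inside = is_grad_on()
after_block = is_grad_on()

SetGradEnabled(False)                # function form: the change persists
after_call = is_grad_on()
SetGradEnabled(True)                 # switch back explicitly
```

This matches the usage pattern in the example for this class, where `_ = tensorplay.set_grad_enabled(True)` changes the mode without a `with` block.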

This context manager is thread local; it will not affect computation in other threads.

Args

  • mode (bool): Whether to enable grad (True) or disable it (False). This can be used to conditionally enable gradients.

INFO

set_grad_enabled is one of several mechanisms that can enable or disable gradients locally; see locally-disable-grad-doc for more information on how they compare.

INFO

This API does not apply to forward-mode AD.

Example

```python
>>> # xdoctest: +SKIP
>>> import tensorplay
>>> x = tensorplay.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with tensorplay.set_grad_enabled(is_train):
...     y = x * 2
>>> y.requires_grad
False
>>> _ = tensorplay.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
True
>>> _ = tensorplay.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
False
```

Methods

__init__(self, mode: bool) -> None [source]

Initialize self. See help(type(self)) for accurate signature.


clone(self) -> 'set_grad_enabled' [source]

Create a copy of this context manager.


Functions

is_grad_enabled() [source]

```python
is_grad_enabled()
```

Released under the Apache 2.0 License.
