
Custom Autograd Functions

TensorPlay allows you to extend its automatic differentiation engine by defining your own forward and backward logic. This is useful for implementing non-differentiable operations with custom gradients or optimizing specific kernels.

Defining a Custom Function

To create a custom autograd function, inherit from tensorplay.autograd.Function and implement the static methods forward and backward.

python
import tensorplay as tp
from tensorplay.autograd import Function

class MyExp(Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        # Save the output: d/dx exp(x) = exp(x), so the result is all backward needs.
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        # Chain rule: scale the incoming gradient by the saved exp(x).
        return grad_output * result

# Usage
def my_exp(x):
    return MyExp.apply(x)

x = tp.randn(3, requires_grad=True)
y = my_exp(x)
y.sum().backward()

print(x.grad)

When to Use Custom Functions?

  1. Efficiency: If you can compute a combined gradient more efficiently than the composition of standard operations.
  2. Stability: Implementing numerically stable versions of functions (e.g., LogSumExp).
  3. Custom Hardware: Interfacing with hardware that has its own differentiation logic.
  4. Non-differentiable Ops: Providing a "surrogate gradient" for operations like step functions or quantization (see the sketch after this list).
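
The sketch below illustrates the surrogate-gradient case: the forward pass applies a hard threshold, while the backward pass substitutes a straight-through estimator. It is a minimal, hypothetical example (the class name StepSTE and the gradient window of |x| <= 1 are illustrative choices), and it assumes TensorPlay tensors support elementwise comparison, abs(), and float() in the same style as the tensor API used above.

python
import tensorplay as tp
from tensorplay.autograd import Function

class StepSTE(Function):
    @staticmethod
    def forward(ctx, x):
        # Hard threshold: non-differentiable, gradient is zero almost everywhere.
        ctx.save_for_backward(x)
        return (x > 0).float()   # assumes comparison returns a tensor

    @staticmethod
    def backward(ctx, grad_output):
        x, = ctx.saved_tensors
        # Straight-through estimator: pass the gradient through unchanged
        # inside |x| <= 1 and block it elsewhere (illustrative window).
        return grad_output * (x.abs() <= 1).float()

x = tp.randn(5, requires_grad=True)
StepSTE.apply(x).sum().backward()
print(x.grad)   # non-zero only where |x| <= 1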

Context Object (ctx)

The ctx object is used to communicate information from the forward pass to the backward pass (see the sketch after the list):

  • ctx.save_for_backward(*tensors): Save tensors needed for the backward pass.
  • ctx.saved_tensors: Access the saved tensors in backward.
  • ctx.mark_dirty(*tensors): Mark tensors that are modified in-place.
  • ctx.mark_non_differentiable(*tensors): Mark outputs that don't require gradients.
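
As an illustration of how these methods combine, here is a minimal sketch of a hypothetical ReLU-like function that also returns its activation mask as a second, non-differentiable output. It assumes the same elementwise comparison and float() conversion as above, and that backward receives one incoming gradient per output when a function returns several.

python
import tensorplay as tp
from tensorplay.autograd import Function

class ReLUWithMask(Function):
    @staticmethod
    def forward(ctx, x):
        mask = (x > 0).float()              # assumed elementwise API
        ctx.save_for_backward(mask)
        ctx.mark_non_differentiable(mask)   # the mask is an auxiliary output
        return x * mask, mask

    @staticmethod
    def backward(ctx, grad_output, grad_mask):
        mask, = ctx.saved_tensors
        # Only the first output carries a gradient; grad_mask is ignored.
        return grad_output * mask

x = tp.randn(4, requires_grad=True)
y, mask = ReLUWithMask.apply(x)
y.sum().backward()
print(x.grad)   # matches the saved mask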
