# Custom Autograd Functions
TensorPlay lets you extend its automatic differentiation engine by defining your own forward and backward logic. This is useful for giving non-differentiable operations custom gradients or for wrapping hand-optimized kernels.
## Defining a Custom Function
To create a custom autograd function, inherit from `tensorplay.autograd.Function` and implement the static methods `forward` and `backward`.
```python
import tensorplay as tp
from tensorplay.autograd import Function


class MyExp(Function):
    @staticmethod
    def forward(ctx, i):
        # Compute the result and stash it for reuse in backward.
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        # d/dx exp(x) = exp(x), so reuse the saved forward result.
        result, = ctx.saved_tensors
        return grad_output * result


# Usage: always call the Function through its apply method.
def my_exp(x):
    return MyExp.apply(x)


x = tp.randn(3, requires_grad=True)
y = my_exp(x)
y.sum().backward()
print(x.grad)  # equal to y, since d/dx exp(x) = exp(x)
```

## When to Use Custom Functions?
- Efficiency: If you can compute a combined gradient more efficiently than the composition of standard operations.
- Stability: Implementing numerically stable versions of functions (e.g., LogSumExp).
- Custom Hardware: Interfacing with hardware that has its own differentiation logic.
- Non-differentiable Ops: Providing a "surrogate gradient" for operations like Step functions or Quantization (see the sketch after this list).
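
As an illustration of the last point, here is a minimal sketch of a hard step function whose backward pass substitutes a straight-through-style surrogate gradient. The class name `HardStep` is hypothetical, and the sketch assumes TensorPlay tensors support elementwise comparison, `abs()`, and a `float()` cast, as in comparable tensor libraries.

```python
import tensorplay as tp
from tensorplay.autograd import Function


class HardStep(Function):
    """Hypothetical surrogate-gradient example: step(x) forward, boxcar gradient backward."""

    @staticmethod
    def forward(ctx, i):
        ctx.save_for_backward(i)
        # Assumes comparison and float-cast behave as in typical tensor libraries.
        return (i > 0).float()

    @staticmethod
    def backward(ctx, grad_output):
        i, = ctx.saved_tensors
        # The true gradient is zero almost everywhere; pass gradients through
        # only where |i| <= 1 (a common straight-through-style surrogate).
        return grad_output * (i.abs() <= 1).float()


x = tp.randn(5, requires_grad=True)
y = HardStep.apply(x)
y.sum().backward()
print(x.grad)  # nonzero only where |x| <= 1
```

The choice of surrogate is a modelling decision; any approximation of the step's derivative can be substituted in `backward`.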
## Context Object (ctx)

The `ctx` object is used to communicate information from the forward pass to the backward pass:
- `ctx.save_for_backward(*tensors)`: Save tensors needed for the backward pass.
- `ctx.saved_tensors`: Access the saved tensors in `backward`.
- `ctx.mark_dirty(*tensors)`: Mark tensors that are modified in-place.
- `ctx.mark_non_differentiable(*tensors)`: Mark outputs that don't require gradients (see the sketch below).
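
As a sketch of how these pieces fit together, the hypothetical `ExpWithSum` below returns a second, auxiliary output and marks it as non-differentiable. It is illustrative only: the class name and auxiliary output are assumptions, and it further assumes that `backward` receives one incoming gradient per output of `forward`, as in similar frameworks.

```python
import tensorplay as tp
from tensorplay.autograd import Function


class ExpWithSum(Function):
    """Hypothetical example: exp(x) plus an auxiliary scalar output."""

    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        total = result.sum()                # auxiliary output, not meant to be differentiated
        ctx.save_for_backward(result)
        ctx.mark_non_differentiable(total)  # no gradient will be routed through `total`
        return result, total

    @staticmethod
    def backward(ctx, grad_output, grad_total):
        # One incoming gradient per forward output (assumed); `grad_total` is ignored.
        result, = ctx.saved_tensors
        return grad_output * result


x = tp.randn(3, requires_grad=True)
y, total = ExpWithSum.apply(x)
y.sum().backward()
print(x.grad)
```

Keeping the auxiliary output out of the differentiable graph avoids having to invent a meaningless gradient for it in `backward`.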
