def recompute_grad(fn):
    """Wrap ``fn`` so its activations are recomputed on the backwards pass.

    Args:
      fn: callable that receives Tensors exclusively as positional arguments
        and returns a tuple of Tensors.

    Returns:
      A function that behaves exactly like ``fn`` when called. Its
      activations, however, are discarded and recomputed on the backwards
      pass (i.e. on a call to tf.gradients).
    """

    def _wrapper(*tensors):
        # Forward the positional arguments as a single tuple. The actual
        # recomputation machinery lives in the module-level helper.
        return _recompute_grad(fn, tensors)

    # Copy fn's __name__, __doc__, __module__, etc. onto the wrapper and
    # expose fn via __wrapped__.
    return functools.wraps(fn)(_wrapper)
评论列表
文章目录