# Imports assumed by this snippet (TensorFlow 1.x with tf.contrib available,
# and common_layers from tensor2tensor):
import tensorflow as tf

from tensor2tensor.layers import common_layers


def _recompute_grad(fn, args):
  """See recompute_grad."""

  cached_vs = []
  cached_arg_scope = []
  def grad_fn(inputs, variables, outputs, output_grads):
    """Recompute outputs for gradient computation."""
    del outputs
    # Recompute outputs
    with tf.control_dependencies(output_grads):
      with tf.contrib.framework.arg_scope(cached_arg_scope[0]):
        with tf.variable_scope(cached_vs[0], reuse=True):
          outputs = fn(*inputs)

    if not isinstance(outputs, (list, tuple)):
      outputs = [outputs]
    outputs = list(outputs)
    grads = tf.gradients(outputs, inputs + variables, output_grads)
    grad_inputs = grads[:len(inputs)]
    grad_vars = grads[len(inputs):]
    return grad_inputs, grad_vars

  @common_layers.fn_with_custom_grad(grad_fn)
  def fn_with_recompute(*args):
    cached_vs.append(tf.get_variable_scope())
    # TODO(rsepassi): Rm conditional in TF 1.5
    if hasattr(tf.contrib.framework, "current_arg_scope"):
      cached_arg_scope.append(tf.contrib.framework.current_arg_scope())
    else:
      cached_arg_scope.append({})
    return fn(*args)

  return fn_with_recompute(*args)
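

# Minimal usage sketch (an assumption, not part of the original module): wrap a
# small feed-forward block with _recompute_grad so its intermediate activations
# are recomputed by grad_fn during backprop instead of being kept in memory.
# The block and layer names below ("_ffn", "hidden", "out") are hypothetical;
# explicit layer names are used so variable_scope reuse works on recompute.
if __name__ == "__main__":

  def _ffn(x):
    h = tf.layers.dense(x, 64, activation=tf.nn.relu, name="hidden")
    return tf.layers.dense(h, 64, name="out")

  x = tf.random_normal([4, 64])
  y = _recompute_grad(_ffn, (x,))
  # Building this gradient invokes grad_fn above, which re-runs _ffn under the
  # cached variable scope before differentiating.
  dx = tf.gradients(y, x)[0]

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.reduce_sum(dx)))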