rev_block.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def _recompute_grad(fn, args):
  """Call `fn(*args)` with a gradient function that re-runs `fn`.

  See recompute_grad.

  The forward call records the active variable scope and, when the installed
  TensorFlow exposes `tf.contrib.framework.current_arg_scope`, the active arg
  scope (otherwise an empty dict). The gradient function discards the
  forward-pass outputs, re-enters the recorded scopes with variable reuse
  enabled, calls `fn` again on the inputs, and differentiates through the
  freshly built outputs.

  Args:
    fn: callable invoked as `fn(*args)`; may return a single value or a
      list/tuple of values.
    args: positional arguments forwarded to `fn`.

  Returns:
    The result of calling the `common_layers.fn_with_custom_grad`-wrapped
    function on `args`.
  """

  # Single-element holders so the nested functions can share the scopes
  # captured during the forward call.
  scope_cache = []
  arg_scope_cache = []

  def grad_fn(inputs, variables, outputs, output_grads):
    """Re-run `fn` on `inputs` and differentiate through the new outputs."""
    # The forward-pass outputs are intentionally unused; they are rebuilt below.
    del outputs

    # Rebuild the outputs under a control dependency on the incoming
    # gradients, inside the arg scope and variable scope saved by the forward
    # call (reuse=True so existing variables are picked up).
    with tf.control_dependencies(output_grads):
      with tf.contrib.framework.arg_scope(arg_scope_cache[0]), \
           tf.variable_scope(scope_cache[0], reuse=True):
        recomputed = fn(*inputs)

    # Normalize to a plain list, whether `fn` returned one value or several.
    if isinstance(recomputed, (list, tuple)):
      recomputed = list(recomputed)
    else:
      recomputed = [recomputed]

    # Gradients come back ordered as inputs first, then variables.
    num_inputs = len(inputs)
    all_grads = tf.gradients(recomputed, inputs + variables, output_grads)
    return all_grads[:num_inputs], all_grads[num_inputs:]

  @common_layers.fn_with_custom_grad(grad_fn)
  def fn_with_recompute(*args):
    """Forward pass: remember the current scopes, then call `fn`."""
    scope_cache.append(tf.get_variable_scope())
    # TODO(rsepassi): Rm conditional in TF 1.5
    framework = tf.contrib.framework
    arg_scope_cache.append(
        framework.current_arg_scope()
        if hasattr(framework, "current_arg_scope") else {})
    return fn(*args)

  return fn_with_recompute(*args)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号