maxpool_gradgrad.py source code

python

Project: tensorflow-forward-ad  Author: renmengye
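import tensorflow as tf


def _max_pool_grad_grad(dy, x, y, ksize, strides, padding, argmax=None):
  """Gradients of MaxPoolGrad.

  Gathers `dy` (same shape as `x`) at the positions that max pooling
  selected, so the result has the shape of the pooling output `y`.
  """
  if argmax is None:
    # Recompute the positions of the maxima when the caller did not pass them.
    _, argmax = tf.nn.max_pool_with_argmax(x, ksize, strides, padding)
  grad_flat = tf.reshape(dy, [-1])
  argmax_flat = tf.reshape(argmax, [-1])

  # The code below treats `argmax` as per-example flat indices (no batch
  # offset), so build the offset b * (H * W * C) for each batch item b.
  x_shape = tf.cast(tf.shape(x), argmax.dtype)
  batch_dim = tf.reshape(
      tf.range(x_shape[0], dtype=argmax.dtype), [-1, 1, 1, 1])
  nelem = tf.reduce_prod(x_shape[1:])
  batch_dim *= nelem

  # Broadcast the per-example offset to the shape of `y`, then flatten it.
  batch_dim += tf.zeros_like(y, dtype=argmax.dtype)
  batch_dim = tf.reshape(batch_dim, [-1])

  # Turn per-example indices into indices into the flattened `dy`.
  argmax_flat += batch_dim
  grad_input = tf.gather(grad_flat, argmax_flat)
  return tf.reshape(grad_input, tf.shape(y))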
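
The index arithmetic in the function above can be illustrated without TensorFlow. The following is a minimal pure-Python sketch, not part of the original project: the helper name `gather_at_argmax` and the toy data are hypothetical, and it assumes, as the function does, that `argmax` holds per-example flat indices with no batch offset.

```python
def gather_at_argmax(dy_flat, argmax, nelem):
    """Pick dy values at the max-pool argmax positions.

    dy_flat: flat list of length batch * nelem (dy flattened example by example).
    argmax:  one list per example of flat indices in [0, nelem).
    nelem:   number of elements per example (H * W * C of the pooling input).
    """
    out = []
    for b, indices in enumerate(argmax):
        offset = b * nelem  # start of example b in the flat buffer
        out.append([dy_flat[offset + i] for i in indices])
    return out


# Two examples, each a 2x2 single-channel image flattened to 4 values.
x = [[1.0, 5.0, 3.0, 2.0], [7.0, 0.0, 4.0, 6.0]]
# A single 2x2 pooling window per example: index of the maximum in each row.
argmax = [[max(range(4), key=row.__getitem__)] for row in x]
dy_flat = [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]

result = gather_at_argmax(dy_flat, argmax, nelem=4)
print(result)  # [[20.0], [50.0]]
```

The explicit `offset = b * nelem` plays the role of `batch_dim` in the TensorFlow version, which computes the same offsets for every output position at once and performs a single `tf.gather`.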