def _max_pool_grad_grad(dy, x, y, ksize, strides, padding, argmax=None):
    """Gradient of MaxPoolGrad: pick the entries of `dy` at the max-pool argmax.

    The flattened `dy` is indexed with the (batch-offset) argmax positions and
    the gathered values are reshaped to the shape of `y`.

    Args:
        dy: Tensor that is flattened and gathered from. Presumably it has the
            same shape as `x`, since the batch offsets are computed from
            `x`'s shape -- TODO confirm against callers.
        x: Tensor whose shape supplies the batch size and the number of
            elements per batch entry; also pooled when `argmax` is omitted.
        y: Tensor that fixes the shape of the result and of the batch-offset
            broadcast. Expected to be rank 4 given the `[-1, 1, 1, 1]`
            reshape below -- TODO confirm.
        ksize: Forwarded to `tf.nn.max_pool_with_argmax` when `argmax` is None.
        strides: Forwarded to `tf.nn.max_pool_with_argmax` when `argmax` is None.
        padding: Forwarded to `tf.nn.max_pool_with_argmax` when `argmax` is None.
        argmax: Optional precomputed index tensor. A per-batch offset is
            always added to it here, so it is treated as holding indices
            local to each batch entry.
            NOTE(review): indices that already include the batch would be
            offset twice -- verify what callers pass.

    Returns:
        A tensor with the shape of `y` holding the gathered values of `dy`.
    """
    if argmax is None:
        # Only the index output of the pooling op is needed.
        _pooled, argmax = tf.nn.max_pool_with_argmax(x, ksize, strides, padding)

    index_dtype = argmax.dtype
    input_dims = tf.cast(tf.shape(x), index_dtype)

    # Number of elements in one batch entry of `x`; used as the stride between
    # consecutive batch entries in the flattened `dy`.
    per_entry_size = tf.reduce_prod(input_dims[1:])

    # One id per batch entry, shaped so it broadcasts against `y`.
    batch_ids = tf.reshape(
        tf.range(input_dims[0], dtype=index_dtype), [-1, 1, 1, 1])

    # Adding zeros shaped like `y` expands the per-batch offsets to one offset
    # per output element.
    batch_offsets = (
        batch_ids * per_entry_size + tf.zeros_like(y, dtype=index_dtype))

    # Shift the per-entry indices into indices over the whole flattened `dy`.
    flat_indices = tf.reshape(argmax, [-1]) + tf.reshape(batch_offsets, [-1])

    gathered = tf.gather(tf.reshape(dy, [-1]), flat_indices)
    return tf.reshape(gathered, tf.shape(y))
maxpool_gradgrad.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录