def MaxPool_FwGrad(op,
                   dx,
                   ksize=(1, 2, 2, 1),
                   strides=(1, 2, 2, 1),
                   padding="SAME",
                   _op_table=None,
                   _grad_table=None):
    """Forward-mode gradient of a max-pooling op.

    For every pooled output position, picks the entry of the input tangent
    ``dx`` at the input position that produced the maximum.

    Args:
        op: The max-pooling operation. ``op.inputs[0]`` is the pooled input
            ``x`` (4D, [N, H, W, C]) and ``op.outputs[0]`` is its output ``y``.
        dx: Tangent of ``x``, same shape as ``x``, or None.
        ksize: Window size passed to ``tf.nn.max_pool_with_argmax``.
        strides: Strides passed to ``tf.nn.max_pool_with_argmax``.
        padding: "SAME" or "VALID", passed to ``tf.nn.max_pool_with_argmax``.
        _op_table: Not used by this function.
        _grad_table: Not used by this function.

    Returns:
        The output tangent, reshaped to the shape of the recomputed argmax
        tensor. Returns None when ``dx`` is None.

    NOTE(review): ksize/strides/padding are NOT read from ``op``. The argmax
    is recomputed with whatever values are passed here (default 2x2 window,
    stride 2, "SAME"). If the forward op used a different configuration, the
    caller must pass it explicitly. Confirm at the call site.
    """
    if dx is None:
        return None
    x = op.inputs[0]
    y = op.outputs[0]

    # Recompute where each max came from; only the indices are needed.
    _, argmax = tf.nn.max_pool_with_argmax(x, ksize, strides, padding)
    index_dtype = argmax.dtype

    dx_flat = tf.reshape(dx, [-1])
    argmax_flat = tf.reshape(argmax, [-1])

    # The argmax indices are flattened within a single example (the batch is
    # not part of the index), so shift each one by b * (H * W * C). This lets
    # it address the fully flattened dx.
    x_shape = tf.cast(tf.shape(x), index_dtype)
    elems_per_example = tf.reduce_prod(x_shape[1:])
    batch_offset = tf.reshape(
        tf.range(x_shape[0], dtype=index_dtype), [-1, 1, 1, 1])
    batch_offset *= elems_per_example
    # Broadcast the [N, 1, 1, 1] offsets to the pooled-output shape before
    # flattening, so they line up element-wise with argmax_flat.
    batch_offset += tf.zeros_like(y, dtype=index_dtype)
    argmax_flat += tf.reshape(batch_offset, [-1])

    dx_selected = tf.gather(dx_flat, argmax_flat)
    return tf.reshape(dx_selected, tf.shape(argmax))
评论列表
文章目录