fwgrad.py source code


Project: tensorflow-forward-ad  Author: renmengye
import tensorflow as tf


def MaxPool_FwGrad(op,
                   dx,
                   ksize=[1, 2, 2, 1],
                   strides=[1, 2, 2, 1],
                   padding="SAME",
                   _op_table=None,
                   _grad_table=None):
  """Forward gradient operator for max pooling.

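  Args:
    op: The max pooling op. Its first input is x, a 4D tensor, [N, H, W, C].
    dx: Forward gradient of the input tensor, 4D tensor, [N, H, W, C].
    ksize: Kernel size of the max pooling operator, list of integers.
    strides: Strides of the max pooling operator, list of integers.
    padding: Padding, string, "SAME" or "VALID".
    _op_table: Not used inside this function.
    _grad_table: Not used inside this function.

  Returns:
    Forward gradient of the pooled output, with the same shape as the op
    output, or None if dx is None.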
  """
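  # No incoming forward gradient means nothing to propagate.
  if dx is None:
    return None
  x = op.inputs[0]
  y = op.outputs[0]
  # Recompute the pooling to get the position of each window's maximum.
  # The batch offset added below assumes these indices are flat positions
  # within each example, i.e. they do not already include the batch term.
  _, argmax = tf.nn.max_pool_with_argmax(x, ksize, strides, padding)
  dx_flat = tf.reshape(dx, [-1])
  argmax_flat = tf.reshape(argmax, [-1])
  # Zeros shaped like the output, used to broadcast the per-example offset.
  y_zero = tf.zeros_like(y, dtype=argmax.dtype)
  x_shape = tf.cast(tf.shape(x), argmax.dtype)
  # Example b starts at offset b * (H * W * C) in the flattened dx.
  batch_dim = tf.reshape(tf.range(x_shape[0], dtype=argmax.dtype), [-1, 1, 1, 1])
  nelem = tf.reduce_prod(x_shape[1:])
  batch_dim *= nelem
  batch_dim += y_zero
  batch_dim = tf.reshape(batch_dim, [-1])
  argmax_flat += batch_dim
  # The forward gradient of each output is dx at the winning input position.
  dx_sel = tf.gather(dx_flat, argmax_flat)
  dy = tf.reshape(dx_sel, tf.shape(argmax))
  return dy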
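
The index arithmetic above can be checked by hand on a tiny input. The following is a pure-Python sketch, not part of the project: the helper names `flatten`, `max_pool_2x2_with_argmax` and `max_pool_2x2_fwgrad` are hypothetical, the window is fixed at 2x2 with stride 2 and no padding, and the argmax indices are assumed to be flat positions within each example, which is what the batch offset in the function above implies.

```python
def flatten(t):
    """Flatten a nested list in row-major order."""
    if not isinstance(t, list):
        return [t]
    return [v for sub in t for v in flatten(sub)]


def max_pool_2x2_with_argmax(x):
    """2x2, stride-2 max pooling over a nested list shaped [N][H][W][C].

    Returns flat lists (y, argmax) in [N, H/2, W/2, C] order. Each argmax
    entry is (h * W + w) * C + c, a flat index inside one example
    (assumption: no batch term, H and W even).
    """
    n, h, w, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    y, argmax = [], []
    for b in range(n):
        for i in range(0, h, 2):
            for j in range(0, w, 2):
                for k in range(c):
                    # (value, per-example flat index) for the four window cells
                    window = [(x[b][i + di][j + dj][k],
                               ((i + di) * w + (j + dj)) * c + k)
                              for di in (0, 1) for dj in (0, 1)]
                    val, idx = max(window, key=lambda t: t[0])
                    y.append(val)
                    argmax.append(idx)
    return y, argmax


def max_pool_2x2_fwgrad(x, dx):
    """Forward gradient of the pooling: pick dx at each window's argmax."""
    n, h, w, c = len(x), len(x[0]), len(x[0][0]), len(x[0][0][0])
    nelem = h * w * c                          # elements per input example
    out_per_example = (h // 2) * (w // 2) * c  # elements per output example
    _, argmax = max_pool_2x2_with_argmax(x)
    dx_flat = flatten(dx)
    # Output position pos belongs to example pos // out_per_example, whose
    # data starts at that example index times nelem in dx_flat.
    return [dx_flat[(pos // out_per_example) * nelem + idx]
            for pos, idx in enumerate(argmax)]


# Two 2x2 single-channel examples; the maxima are 5 (flat index 1) and 9 (flat index 3).
x = [[[[1], [5]], [[3], [2]]], [[[7], [0]], [[4], [9]]]]
dx = [[[[10], [20]], [[30], [40]]], [[[50], [60]], [[70], [80]]]]
print(max_pool_2x2_fwgrad(x, dx))  # [20, 80]
```

Without the `(pos // out_per_example) * nelem` term, the second example would read `dx_flat[3]` (40) from the first example instead of 80, which is the error the `batch_dim` offset in the TensorFlow code guards against.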