fast_layers.py source code

Language: Python | Project: DLCourse_AS | Author: TomoyaKung
import numpy as np


def max_pool_backward_reshape(dout, cache):
  """
  A fast implementation of the backward pass for the max pooling layer that
  uses some clever broadcasting and reshaping.

  This can only be used if the forward pass was computed using
  max_pool_forward_reshape.

  NOTE: If a pooling window contains multiple argmaxes, the masked assignment
  below routes the upstream gradient to ALL of them rather than picking one,
  which by itself would give an incorrect gradient. The division by
  np.sum(mask, ...) that follows splits the gradient equally among the tied
  elements instead, which yields a valid subgradient. That division costs
  roughly 40% extra time; since exact ties are unlikely in practice, you can
  comment it out for speed at the cost of correctness when ties do occur.
  """
  x, x_reshaped, out = cache

  dx_reshaped = np.zeros_like(x_reshaped)
  # Broadcast out against x_reshaped so that mask marks the max element(s)
  # of every pooling window.
  out_newaxis = out[:, :, :, np.newaxis, :, np.newaxis]
  mask = (x_reshaped == out_newaxis)
  # Broadcast the upstream gradient to the windowed shape and route it
  # to the argmax positions.
  dout_newaxis = dout[:, :, :, np.newaxis, :, np.newaxis]
  dout_broadcast, _ = np.broadcast_arrays(dout_newaxis, dx_reshaped)
  dx_reshaped[mask] = dout_broadcast[mask]
  # Split the gradient equally among tied maxima (see NOTE in the docstring).
  dx_reshaped /= np.sum(mask, axis=(3, 5), keepdims=True)
  dx = dx_reshaped.reshape(x.shape)

  return dx
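
The cache this backward pass unpacks, (x, x_reshaped, out), must come from the matching max_pool_forward_reshape, which is not shown on this page. The sketch below is a minimal reconstruction of that forward pass, assuming square, non-overlapping pooling (pool_height == pool_width == stride), which is what the reshape trick requires; the shape check at the end, with hypothetical inputs, just wires the two functions together.

def max_pool_forward_reshape(x, pool_param):
  """
  Sketch of the matching forward pass. Only valid for square,
  non-overlapping pooling: pool_height == pool_width == stride.
  """
  N, C, H, W = x.shape
  pool_height, pool_width = pool_param['pool_height'], pool_param['pool_width']
  stride = pool_param['stride']
  assert pool_height == pool_width == stride, 'Invalid pool params'
  assert H % pool_height == 0
  assert W % pool_width == 0
  # Fold each pooling window into its own pair of axes (3 and 5), then
  # reduce over them; the backward pass broadcasts over these same axes.
  x_reshaped = x.reshape(N, C, H // pool_height, pool_height,
                         W // pool_width, pool_width)
  out = x_reshaped.max(axis=3).max(axis=4)
  cache = (x, x_reshaped, out)
  return out, cache


# Quick shape check with hypothetical inputs: 2x2 pooling on a 4x4 map.
x = np.random.randn(2, 3, 4, 4)
pool_param = {'pool_height': 2, 'pool_width': 2, 'stride': 2}
out, cache = max_pool_forward_reshape(x, pool_param)
dout = np.random.randn(*out.shape)
dx = max_pool_backward_reshape(dout, cache)
assert out.shape == (2, 3, 2, 2) and dx.shape == x.shape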