def max_pool_backward_reshape(dout, cache):
    """
    A fast implementation of the backward pass for the max pooling layer that
    uses some clever broadcasting and reshaping.

    This can only be used if the forward pass was computed using
    max_pool_forward_reshape.

    NOTE: When several elements of a pooling window tie for the maximum, the
    upstream gradient for that window is split equally among all of them.
    This yields a valid subgradient, instead of routing the full upstream
    gradient to every tied element, which would be incorrect. The extra
    division costs some speed but only changes the result when ties occur.
    If a window's output equals none of its inputs (e.g. a NaN), its tie
    count is 0 and the division produces NaN for that window.

    Inputs:
    - dout: Upstream derivatives, laid out like the pooled output ``out``
      (4-D; both are indexed with four slices below).
    - cache: Tuple ``(x, x_reshaped, out)``. ``x_reshaped`` is a 6-D array
      holding the values of ``x`` in which axes 3 and 5 span each pooling
      window; ``out`` is the 4-D pooled output. Presumably the layout is
      (N, C, H_out, pool_h, W_out, pool_w) as built by
      max_pool_forward_reshape -- confirm against that function.

    Returns:
    - dx: Gradient with respect to x, with the same shape as x. Floating
      inputs keep their dtype; non-floating inputs yield a float64 gradient
      so the tie-splitting division is well defined.
    """
    x, x_reshaped, out = cache

    # Accumulate in floating point: the in-place division below cannot cast
    # a float result back into an integer array (NumPy raises a TypeError).
    if np.issubdtype(x_reshaped.dtype, np.floating):
        grad_dtype = x_reshaped.dtype
    else:
        grad_dtype = np.float64
    dx_reshaped = np.zeros_like(x_reshaped, dtype=grad_dtype)

    # Insert singleton window axes (3 and 5) so out/dout broadcast against
    # x_reshaped.
    out_newaxis = out[:, :, :, np.newaxis, :, np.newaxis]
    mask = (x_reshaped == out_newaxis)  # True where an input equals its window max
    dout_newaxis = dout[:, :, :, np.newaxis, :, np.newaxis]
    # broadcast_arrays returns a read-only view; it is only read from here.
    dout_broadcast, _ = np.broadcast_arrays(dout_newaxis, dx_reshaped)
    dx_reshaped[mask] = dout_broadcast[mask]

    # Split each window's gradient equally among its tied maxima.
    num_maxima = np.sum(mask, axis=(3, 5), keepdims=True)
    dx_reshaped /= num_maxima

    dx = dx_reshaped.reshape(x.shape)
    return dx
评论列表
文章目录