def update_grad_input(self, x, grad_output):
    """Back-propagate a Gram-matrix gradient to the layer input.

    The forward pass is not visible here; it is presumably
    ``G = F @ F.T`` per sample, where ``F`` is the input reshaped to
    ``(C, H * W)``, optionally divided by ``C * H * W`` when
    ``self.normalize`` is set -- TODO confirm against the forward method.
    Under that assumption the gradient w.r.t. ``F`` is
    ``grad_output @ F + grad_output.T @ F``, scaled by the same factor.

    Args:
        x: Input array of shape ``(C, H, W)`` or ``(N, C, H, W)``.
        grad_output: Gradient w.r.t. the Gram matrix, shape ``(C, C)`` for
            a 3-D input or ``(N, C, C)`` for a 4-D input.

    Returns:
        The gradient w.r.t. ``x``, with the same shape as ``x``. It is also
        stored on ``self.grad_input``. The flattened ``(..., C, H * W)``
        gradient is kept on ``self.buffer``.

    Raises:
        ValueError: If ``x`` is neither 3-D nor 4-D.
    """
    x = np.asarray(x)
    grad_output = np.asarray(grad_output)
    if x.ndim not in (3, 4):
        raise ValueError(f"expected a 3-D or 4-D input, got {x.ndim}-D")

    C, H, W = x.shape[-3:]
    # Collapse the spatial axes: (..., C, H, W) -> (..., C, H * W).
    # reshape (not ndarray.view, which reinterprets dtype) is required here.
    x_flat = x.reshape(x.shape[:-2] + (H * W,))

    # Both factors of F @ F.T depend on F, so the gradient has two terms.
    # np.matmul broadcasts over the leading batch axis in the 4-D case, and
    # swapping the last two axes is the per-sample transpose (the original
    # Lua code's 1-based transpose(2, 3)).
    grad_output_t = np.swapaxes(grad_output, -1, -2)
    self.buffer = np.matmul(grad_output, x_flat)
    self.buffer = self.buffer + np.matmul(grad_output_t, x_flat)

    if self.normalize:
        # Scale by the per-sample element count before reshaping, so the
        # returned gradient is normalized without relying on view aliasing.
        self.buffer = self.buffer / (C * H * W)

    self.grad_input = self.buffer.reshape(x.shape)
    return self.grad_input
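# Web-page residue from the original article (not code):
# "评论列表" (comment list), "文章目录" (table of contents).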