def backward(self, grad_output):
    """Return the gradients w.r.t. the saved ``input`` and ``filter``.

    NOTE(review): the matching ``forward`` is not visible here. These
    formulas assume it computed ``correlate2d(input, filter, mode='valid')``
    and saved ``(input, filter)`` with ``save_for_backward`` — confirm
    against the forward pass. For that forward,
    ``out[i, j] = sum_{m, n} input[i + m, j + n] * filter[m, n]``, so:

    * ``dL/d input[p, q] = sum_{i, j} grad_output[i, j] * filter[p - i, q - j]``,
      which is the *full* convolution of ``grad_output`` with ``filter``;
    * ``dL/d filter[m, n] = sum_{i, j} grad_output[i, j] * input[i + m, j + n]``,
      which is the *valid* cross-correlation of ``input`` with ``grad_output``.

    Args:
        grad_output: 2-D tensor with the gradient of the loss w.r.t. the
            forward output. It is converted with ``.numpy()``, so it must
            be a CPU tensor that does not require grad.

    Returns:
        ``(grad_input, grad_filter)`` as ``torch.FloatTensor`` objects. Under
        the forward-pass assumption above, they have the shapes of the saved
        ``input`` and ``filter`` respectively.
    """
    # Imported locally: the file's import block is not visible from here.
    from scipy.signal import correlate2d

    inp, kernel = self.saved_tensors
    grad = grad_output.numpy()
    # convolve2d already flips the kernel, which is exactly the index
    # reversal the chain rule produces. The previous ``filter.t()``
    # transposed instead, giving wrong values (and a wrong shape for
    # non-square filters).
    grad_input = convolve2d(grad, kernel.numpy(), mode='full')
    # Cross-correlate, with no flip. Using convolve2d here would flip
    # grad_output and yield the gradient of a different operation.
    grad_filter = correlate2d(inp.numpy(), grad, mode='valid')
    return torch.FloatTensor(grad_input), torch.FloatTensor(grad_filter)
评论列表
文章目录