def backward(self, grad_output):
    input, target = self.saved_tensors
    # The per-element gradient starts as a copy of the target (+1 or -1).
    grad_input = input.new().resize_as_(input).copy_(target)
    # Zero the gradient where target == -1 and input already exceeds the margin
    # (the margin term contributes no loss there).
    grad_input[torch.mul(torch.eq(target, -1), torch.gt(input, self.margin))] = 0
    # Match the forward reduction: divide by the element count when averaging.
    if self.size_average:
        grad_input.mul_(1. / input.nelement())
    # Scale by the incoming gradient (skip the multiply when it is 1).
    if grad_output[0] != 1:
        grad_input.mul_(grad_output[0])
    # Only input receives a gradient; target gets None.
    return grad_input, None
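
Taken together, this is the gradient of a hinge-embedding style loss: the loss equals input where target == 1 and max(0, margin - input) where target == -1, so the per-element gradient is +1, -1, or 0 depending on whether the margin is already satisfied. Below is a minimal sketch of the same idea written against the current ctx-based torch.autograd.Function API; the class name HingeEmbeddingLossSketch, the explicit margin/size_average arguments, and the mean/sum reduction are illustrative assumptions, not the original module's definition.

import torch
from torch.autograd import Function

class HingeEmbeddingLossSketch(Function):
    # Hypothetical name; a sketch of the loss above with the ctx-based API.

    @staticmethod
    def forward(ctx, input, target, margin, size_average):
        ctx.save_for_backward(input, target)
        ctx.margin = margin
        ctx.size_average = size_average
        # loss = input where target == 1, max(0, margin - input) where target == -1
        loss = torch.where(target == 1, input, (margin - input).clamp(min=0))
        return loss.mean() if size_average else loss.sum()

    @staticmethod
    def backward(ctx, grad_output):
        input, target = ctx.saved_tensors
        # Same logic as the legacy backward: +/-1 from the target, zeroed past the margin.
        grad_input = target.clone().to(input.dtype)
        grad_input[(target == -1) & (input > ctx.margin)] = 0
        if ctx.size_average:
            grad_input = grad_input / input.nelement()
        # margin and size_average are non-tensor inputs, so they get no gradient.
        return grad_input * grad_output, None, None, None

x = torch.randn(5, requires_grad=True)
y = torch.tensor([1., -1., 1., -1., 1.])
loss = HingeEmbeddingLossSketch.apply(x, y, 1.0, True)
loss.backward()
print(x.grad)

Returning None for target, margin, and size_average mirrors the return grad_input, None of the legacy version: only the input participates in differentiation.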