def updateGradInput(self, input, y):
    """Compute the gradient of the criterion with respect to ``input``.

    The gradient is seeded with a copy of ``y``. Entries where ``y == -1``
    and ``input > self.margin`` are then set to zero. When
    ``self.sizeAverage`` is truthy, the result is scaled by
    ``1 / input.nelement()``. For an empty ``input``, the scaling is
    skipped instead of dividing by zero.

    NOTE(review): this matches the derivative of a hinge-embedding loss
    (``x`` for ``y == 1``, ``max(0, margin - x)`` for ``y == -1``) and
    assumes ``y`` holds only +1 / -1 labels. Confirm against
    ``updateOutput``.

    Args:
        input: tensor of inputs; ``self.gradInput`` is resized to its shape.
        y: label tensor copied into ``self.gradInput``. It must be
            broadcastable to ``input``'s shape for ``copy_`` to succeed.

    Returns:
        ``self.gradInput``, resized and filled in place.
    """
    # Reuse the existing gradInput buffer: match input's shape, then seed it
    # with the labels.
    self.gradInput.resize_as_(input).copy_(y)

    # Entries labelled -1 whose input already exceeds the margin get zero
    # gradient. ``&`` is the element-wise logical AND of the two masks.
    inactive = torch.eq(y, -1) & torch.gt(input, self.margin)
    self.gradInput[inactive] = 0

    numel = input.nelement()
    # An empty input would make ``1. / numel`` raise ZeroDivisionError;
    # there is nothing to scale in that case anyway.
    if self.sizeAverage and numel > 0:
        self.gradInput.mul_(1. / numel)
    return self.gradInput
评论列表
文章目录