def forward(self, input, target):
    """Compute the hinge-embedding loss of ``input`` against ``target``.

    Two partial sums are accumulated:

    * the sum of ``input`` over every element whose target is not ``-1``;
    * the sum of ``max(0, self.margin - input)`` over every element whose
      target is not ``1``.

    The total is divided by ``input.nelement()`` when ``self.size_average``
    is truthy.  ``input`` and ``target`` are handed to
    ``self.save_for_backward`` before returning; ``input`` itself is never
    modified (all in-place work happens on a private copy).

    Args:
        input: tensor of values to score.
        target: tensor compared element-wise against ``1`` and ``-1``;
            presumably the same shape as ``input`` -- the masks built from
            it are used to index a copy of ``input``.

    Returns:
        A one-element tensor, allocated with ``input.new``, holding the loss.
    """
    # Private working copy so the caller's tensor is left untouched.
    scratch = input.clone()

    # First term: keep the raw input except where the target is -1.
    scratch[torch.eq(target, -1.)] = 0
    output = scratch.sum()

    # Second term: max(0, margin - input), dropped where the target is 1.
    # ``sub_`` replaces the deprecated ``add_(-1, input)`` overload and
    # ``clamp_(min=0)`` replaces the removed ``cmax_(0)``.
    scratch.fill_(self.margin).sub_(input)
    scratch.clamp_(min=0)
    scratch[torch.eq(target, 1.)] = 0
    output += scratch.sum()

    if self.size_average:
        output = output / input.nelement()

    self.save_for_backward(input, target)

    # ``fill_`` accepts both a Python number and a 0-dim tensor, so this
    # works whichever of the two ``sum()`` returned.
    return input.new(1).fill_(output)
# 评论列表 (comment list)
# 文章目录 (table of contents)