def forward(ctx, input, target, margin, size_average):
    """Compute the hinge embedding loss and save what backward needs.

    The loss is the sum of two masked terms:

    * ``input`` summed over every element whose ``target`` is not ``-1``;
    * ``max(0, margin - input)`` summed over every element whose ``target``
      is not ``1``.

    For the usual +1/-1 targets, each element contributes to exactly one term.
    Elements with any other target value contribute to both terms.

    Args:
        ctx: Autograd context. ``margin`` and ``size_average`` are stored on
            it, and ``(input, target)`` are saved via ``save_for_backward``.
        input: Tensor of scores. It is not modified. Presumably a
            floating-point tensor — TODO confirm.
        target: Tensor compared element-wise against ``1`` and ``-1``.
            Assumed to have the same shape as ``input`` — TODO confirm.
        margin: Scalar margin used in the hinge term.
        size_average: If truthy, divide the summed loss by
            ``input.nelement()``.

    Returns:
        A 1-element tensor with ``input``'s dtype holding the loss.
    """
    ctx.margin = margin
    ctx.size_average = size_average

    # Linear term: keep the raw input everywhere except where target == -1.
    linear = input.clone()
    linear[torch.eq(target, -1.)] = 0
    output = linear.sum()

    # Hinge term: max(0, margin - input), dropped where target == 1.
    # This replaces the deprecated `fill_(margin).add_(-1, input)` overload.
    hinge = torch.clamp(ctx.margin - input, min=0)
    hinge[torch.eq(target, 1.)] = 0
    output = output + hinge.sum()

    if ctx.size_average:
        output = output / input.nelement()

    ctx.save_for_backward(input, target)
    # `sum()` yields a 0-dim tensor; the caller expects a 1-element tensor
    # in input's dtype (what the legacy `input.new((output,))` produced).
    return output.reshape(1).to(dtype=input.dtype)
评论列表
文章目录