def forward(ctx, input, target, grad_output, margin, size_average):
    """Compute the input gradient from ``input``, ``target`` and ``margin``.

    The result starts as a copy of ``target`` (same shape/dtype as ``input``).
    Entries where ``target == -1`` and ``input > margin`` are then set to 0.
    The result is optionally averaged over the number of elements and then
    scaled by the single value held in ``grad_output``.

    NOTE(review): this looks like the backward pass of a hinge-embedding style
    loss, written as the ``forward`` of an autograd Function so that it can be
    differentiated again. Confirm against the enclosing class.

    Args:
        ctx: autograd context. ``margin`` and ``size_average`` are stored on
            it, and ``input``, ``target`` and ``grad_output`` are saved with
            ``save_for_backward``.
        input: tensor whose gradient is produced. Its shape and dtype
            determine the result's.
        target: tensor copied into the result. Assumed to hold +1 / -1 labels
            and to be broadcastable to ``input`` — TODO confirm.
        grad_output: tensor holding a single upstream gradient value. It may
            be 0-dim or have one element.
        margin: threshold compared against ``input`` for ``target == -1``
            entries.
        size_average: if true, divide the result by ``input.numel()``.
            Skipped for empty input.

    Returns:
        A new tensor shaped like ``input``.
    """
    ctx.margin = margin
    ctx.size_average = size_average
    ctx.save_for_backward(input, target, grad_output)

    grad_input = torch.empty_like(input).copy_(target)
    # Negative-labelled entries already beyond the margin contribute no gradient.
    grad_input[(target == -1) & (input > ctx.margin)] = 0

    numel = input.numel()
    # Guard against division by zero when the input is empty.
    if ctx.size_average and numel > 0:
        grad_input.mul_(1. / numel)

    # reshape(-1)[0] reads the scalar from either a 0-dim or a 1-element tensor.
    scale = grad_output.reshape(-1)[0]
    if scale != 1:
        grad_input.mul_(scale)
    return grad_input
评论列表
文章目录