def forward(ctx, input1, input2, y, margin, size_average):
    """Compute the cosine embedding loss and cache intermediates on ``ctx``.

    Reductions are over dim 1; inputs are presumably (N, D) -- TODO confirm
    against callers. For each row i::

        cos_i  = <x1_i, x2_i> / sqrt((|x1_i|^2 + eps) * (|x2_i|^2 + eps))
        loss_i = 1 - cos_i                 if y_i == 1
        loss_i = max(0, cos_i - margin)    if y_i == -1

    Rows whose label is neither 1 nor -1 keep ``cos_i`` as their value.

    Args:
        ctx: autograd context; must provide ``save_for_backward``.
        input1: first batch of vectors.
        input2: second batch of vectors, same shape as ``input1``.
        y: one label per row (``y.size(0)`` rows), expected to be 1 or -1.
        margin: offset subtracted from the cosine for rows with ``y == -1``.
        size_average: if true, divide the summed loss by ``y.size(0)``.

    Returns:
        A 1-element tensor holding the summed (or averaged) loss.

    Side effects:
        Saves ``(input1, input2, y)`` via ``ctx.save_for_backward`` and stores
        ``margin``, ``size_average``, ``w1``, ``w22``, ``w32``, ``w`` and
        ``_outputs`` on ``ctx`` (presumably consumed by a backward pass that
        is not shown here -- keep their meaning stable).
    """
    ctx.margin = margin
    ctx.size_average = size_average
    epsilon = 1e-12

    # w1: per-row dot product <x1, x2>, kept 2-D via keepdim.
    ctx.w1 = (input1 * input2).sum(1, keepdim=True)
    # w22 / w32: reciprocal squared norms; epsilon guards against 1 / 0.
    ctx.w22 = (input1 * input1).sum(1, keepdim=True).add_(epsilon).reciprocal_()
    ctx.w32 = (input2 * input2).sum(1, keepdim=True).add_(epsilon).reciprocal_()
    # w: 1 / (|x1| * |x2|), the factor that turns w1 into a cosine.
    ctx.w = (ctx.w22 * ctx.w32).sqrt_()

    # Per-row cosine as a 1-D view; the masked assignments below overwrite
    # it in place, so ctx._outputs ends up holding the per-row loss.
    ctx._outputs = (ctx.w1 * ctx.w).select(1, 0)
    losses = ctx._outputs

    # Boolean masks: uint8 masks are deprecated for tensor indexing.
    neg = y.eq(-1)
    pos = y.eq(1)
    losses[neg] = (losses[neg] - ctx.margin).clamp_(min=0)
    losses[pos] = 1 - losses[pos]

    output = losses.sum()
    if ctx.size_average:
        output = output / y.size(0)
    ctx.save_for_backward(input1, input2, y)
    # 1-element result with the same dtype/device as the computed loss.
    return output.reshape(1)
评论列表
文章目录