def forward(self, input1, input2, y):
    """Compute the cosine embedding loss for a batch of input pairs.

    Written as an instance-method forward pass: it stores intermediate
    tensors on ``self`` and calls ``self.save_for_backward`` so a companion
    ``backward`` (not visible in this chunk) can presumably reuse them.

    Per row ``i`` the cosine similarity is computed as::

        cos_i = sum(x1_i * x2_i)
                / sqrt((sum(x1_i ** 2) + eps) * (sum(x2_i ** 2) + eps))

    with ``eps = 1e-12``. The per-row loss is ``max(0, cos_i - margin)``
    where ``y_i == -1`` and ``1 - cos_i`` where ``y_i == 1``. The per-row
    losses are summed and, if ``self.size_average`` is truthy, divided by
    ``y.size(0)``.

    Args:
        input1: first batch of vectors. Reductions run over dim 1, so this is
            presumably shaped ``(N, D)`` -- TODO confirm.
        input2: second batch, presumably the same shape as ``input1``.
        y: per-row labels compared against ``-1`` and ``1``; presumably
            shaped ``(N,)`` -- TODO confirm.

    Returns:
        A one-element tensor created via ``input1.new`` holding the loss.

    Side effects:
        Sets ``self.w1``, ``self.w22``, ``self.w32``, ``self.w`` and
        ``self._outputs``. Also calls
        ``self.save_for_backward(input1, input2, y)``.

    NOTE(review): this body appears to use the legacy output-first torch
    API, where the first argument receives the result
    (``torch.mul(result, a, b)``, ``torch.sum(result, src, dim)``,
    ``torch.eq(result, a, b)``). It also uses ``Tensor.cmax_``. Those forms
    are not available in current PyTorch, so confirm the targeted torch
    version before changing anything here.

    NOTE(review): rows whose label is neither ``1`` nor ``-1`` are not
    rewritten by either mask below. Their raw cosine value is therefore
    included in the summed loss.
    """
    # Fresh scratch tensors kept on `self`; presumably read again by
    # `backward` (not shown), so their names and shapes matter.
    self.w1 = input1.new()
    self.w22 = input1.new()
    self.w = input1.new()
    self.w32 = input1.new()
    self._outputs = input1.new()
    buffer = input1.new()  # reused scratch for the element-wise products
    # Empty mask tensor from a helper defined elsewhere; presumably a byte
    # tensor on the same device as input1 -- TODO confirm.
    _idx = self._new_idx(input1)
    # w1 = row-wise dot product: sum over dim 1 of input1 * input2.
    torch.mul(buffer, input1, input2)
    torch.sum(self.w1, buffer, 1)
    # Added to the squared norms, presumably so an all-zero row does not
    # cause a division by zero in the reciprocals below.
    epsilon = 1e-12
    # w22 = ||input1||^2 + epsilon per row (inverted a few lines below).
    torch.mul(buffer, input1, input1)
    torch.sum(self.w22, buffer, 1).add_(epsilon)
    # `_outputs` temporarily holds ones and serves as the numerator for BOTH
    # reciprocals. It is overwritten with the cosine further down, so these
    # statements must stay in this order.
    self._outputs.resize_as_(self.w22).fill_(1)
    torch.div(self.w22, self._outputs, self.w22)  # w22 = 1 / w22
    self.w.resize_as_(self.w22).copy_(self.w22)
    # w32 = 1 / (||input2||^2 + epsilon) per row.
    torch.mul(buffer, input2, input2)
    torch.sum(self.w32, buffer, 1).add_(epsilon)
    torch.div(self.w32, self._outputs, self.w32)
    # w = 1 / sqrt((||input1||^2 + eps) * (||input2||^2 + eps)).
    self.w.mul_(self.w32)
    self.w.sqrt_()
    # Per-row cosine similarity: w1 * w (overwrites the ones buffer).
    torch.mul(self._outputs, self.w1, self.w)
    # Drop the size-1 dim 1 left by the reductions. This relies on the sum
    # above keeping dim 1 (legacy behaviour -- TODO confirm).
    self._outputs = self._outputs.select(1, 0)
    # Dissimilar pairs (y == -1): max(0, cos - margin). `cmax_(0)` is the
    # legacy element-wise max against a scalar. Masked indexing yields a
    # copy, hence the explicit write-back.
    torch.eq(_idx, y, -1)
    self._outputs[_idx] = self._outputs[_idx].add_(-self.margin).cmax_(0)
    # Similar pairs (y == 1): 1 - cos.
    torch.eq(_idx, y, 1)
    self._outputs[_idx] = self._outputs[_idx].mul_(-1).add_(1)
    output = self._outputs.sum()
    if self.size_average:
        # Average over the batch, using the number of labels as the batch size.
        output = output / y.size(0)
    self.save_for_backward(input1, input2, y)
    # Wrap the scalar loss in a one-element tensor of input1's type.
    return input1.new((output,))
评论列表
文章目录