def backward(self, grad_output):
    """Compute the gradients of the loss with respect to both inputs.

    The code appears to use the legacy torch API in which the output tensor
    is passed as the first positional argument (e.g. ``torch.mul(out, a, b)``
    writes ``a * b`` into ``out``).

    It reads values cached elsewhere on ``self`` (not visible here):
    ``self.w1``, ``self.w22``, ``self.w32``, ``self.w``, ``self._outputs``
    and ``self.size_average``, plus the saved tensors ``(v1, v2, y)``.

    Args:
        grad_output: upstream gradient. Only ``grad_output[0]`` is used, as a
            scalar scale factor, so this presumably is a one-element tensor
            holding the gradient of a scalar loss (TODO confirm).

    Returns:
        Tuple ``(gw1, gw2, None)``: gradients for ``v1`` and ``v2``, and
        ``None`` for the third saved tensor ``y``.
    """
    v1, v2, y = self.saved_tensors

    buffer = v1.new()
    # Flat per-sample mask buffer. torch.le / torch.eq write into it below,
    # so it must stay a plain tensor and never be rebound to an expanded view.
    _idx = self._new_idx(v1)

    gw1 = grad_output.new()
    gw2 = grad_output.new()
    gw1.resize_as_(v1).copy_(v2)
    gw2.resize_as_(v1).copy_(v1)

    # gw1 = w * (v2 - w1 * w22 * v1)
    torch.mul(buffer, self.w1, self.w22)
    gw1.addcmul_(-1, buffer.expand_as(v1), v1)
    gw1.mul_(self.w.expand_as(v1))

    # gw2 = w * (v1 - w1 * w32 * v2)
    torch.mul(buffer, self.w1, self.w32)
    gw2.addcmul_(-1, buffer.expand_as(v1), v2)
    gw2.mul_(self.w.expand_as(v1))

    # Zero the gradient rows whose cached `_outputs` value is <= 0.
    torch.le(_idx, self._outputs, 0)
    mask = _idx.view(-1, 1).expand(gw1.size())
    gw1[mask] = 0
    gw2[mask] = 0

    # Negate the gradient rows whose label y == 1.
    # The comparison is written into the flat `_idx` buffer, not into the
    # expanded `mask` view above. Several elements of an expanded view share
    # one memory location, so it must not be used as an output tensor.
    torch.eq(_idx, y, 1)
    mask = _idx.view(-1, 1).expand(gw2.size())
    gw1[mask] = gw1[mask].mul_(-1)
    gw2[mask] = gw2[mask].mul_(-1)

    if self.size_average:
        gw1.div_(y.size(0))
        gw2.div_(y.size(0))

    # Scale by the upstream gradient as a scalar. Multiplying by the whole
    # grad_output tensor would depend on broadcasting a one-element tensor
    # against the (rows, cols) gradients.
    scale = grad_output[0]
    if scale != 1:
        gw1.mul_(scale)
        gw2.mul_(scale)

    return gw1, gw2, None
评论列表
文章目录