def updateOutput(self, input, target):
    """Compute the (optionally weighted) binary cross-entropy and store it.

    The loss is
        -sum(w * (t * log(x + eps) + (1 - t) * log(1 - x + eps)))
    and is divided by ``input.nelement()`` when ``self.sizeAverage`` is truthy.
    ``eps`` is read from ``self.eps``. ``w`` is ``self.weights`` (or is
    omitted when that is None).

    ``self.buffer`` is lazily allocated with ``input.new()`` and reused as
    scratch space. After the call it holds the weighted ``log(1 - x + eps)``
    term.

    Args:
        input: predicted probabilities.
        target: targets with the same number of elements as ``input``.
            When ``self.weights`` is set and ``target`` is not 1-D, the
            weights are viewed as ``(1, target.size(1))`` and expanded to
            ``target``'s shape. NOTE(review): this looks like it expects a
            2-D (batch, classes) target; confirm against callers.

    Returns:
        The loss tensor, also stored in ``self.output``.

    Raises:
        RuntimeError: if ``input`` and ``target`` differ in element count.
    """
    if input.nelement() != target.nelement():
        raise RuntimeError("input and target size mismatch")

    # Lazily create the scratch tensor, then size it to match the input.
    if self.buffer is None:
        self.buffer = input.new()
    scratch = self.buffer
    scratch.resize_as_(input)

    w = self.weights
    if w is not None and target.dim() != 1:
        # Broadcast per-column weights across every row of the target.
        w = self.weights.view(1, target.size(1)).expand_as(target)

    def _weighted(t):
        # Scale the scratch tensor in place by the weights, if any.
        if w is not None:
            t.mul_(w)
        return t

    flat_target = target.contiguous().view(-1)

    # Positive term: dot(t, w * log(x + eps)).
    _weighted(torch.add(input, self.eps, out=scratch).log_())
    # Flatten scratch only transiently; it keeps input's shape for reuse below.
    total = torch.dot(flat_target, scratch.contiguous().view(-1))

    # Negative term: sum((1 - t) * w * log(1 - x + eps)),
    # expanded as sum(b) - dot(t, b) where b = w * log(1 - x + eps).
    _weighted(torch.mul(input, -1, out=scratch).add_(1 + self.eps).log_())
    total = total + torch.sum(scratch)
    total = total - torch.dot(flat_target, scratch.contiguous().view(-1))

    if self.sizeAverage:
        total = total / input.nelement()

    self.output = - total
    return self.output
# (scraped web-page residue, not code) Comment list
# (scraped web-page residue, not code) Table of contents