def updateOutput(self, input, target):
    """Compute the (optionally weighted) binary cross-entropy loss.

    Evaluates

        -sum(w * (log(input + eps) * target
                  + log(1 - input + eps) * (1 - target)))

    where ``w`` is ``self.weights`` (or 1 everywhere when it is ``None``),
    and divides by ``input.nelement()`` when ``self.sizeAverage`` is truthy.

    Args:
        input: tensor of predictions. Presumably probabilities in [0, 1];
            ``self.eps`` is added before each ``log`` to keep it finite.
        target: tensor of targets with the same number of elements as
            ``input``.

    Returns:
        The loss value, which is also stored in ``self.output``.

    Raises:
        RuntimeError: if ``input`` and ``target`` differ in element count.
    """
    if input.nelement() != target.nelement():
        raise RuntimeError("input and target size mismatch")

    # Scratch tensor cached on the instance and reused across calls.
    if self.buffer is None:
        self.buffer = input.new()
    buffer = self.buffer
    buffer.resize_as_(input)

    weights = self.weights
    if weights is not None and target.dim() != 1:
        # Non-1D target: treat weights as one value per column (dim 1) and
        # repeat them along the leading dimension.
        weights = weights.view(1, target.size(1)).expand_as(target)

    # torch.dot only accepts 1-D tensors, so flatten the operands; the
    # element counts were validated above, which makes this a full-tensor
    # reduction for batched (multi-dimensional) inputs as well.
    flat_target = target.contiguous().view(-1)

    # log(input) * target
    torch.add(input, self.eps, out=buffer).log_()
    if weights is not None:
        buffer.mul_(weights)
    output = torch.dot(flat_target, buffer.contiguous().view(-1))

    # log(1 - input) * (1 - target), expanded as
    # sum(log(1 - input)) - dot(target, log(1 - input))
    torch.mul(input, -1, out=buffer).add_(1 + self.eps).log_()
    if weights is not None:
        buffer.mul_(weights)
    output = output + torch.sum(buffer)
    output = output - torch.dot(flat_target, buffer.contiguous().view(-1))

    if self.sizeAverage:
        output = output / input.nelement()

    self.output = -output
    return self.output
评论列表
文章目录