BCECriterion.py 文件源码

python
阅读 48 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def updateOutput(self, input, target):
        """Forward pass: binary cross-entropy of ``input`` against ``target``.

        Computes::

            -sum(w * (target * log(input + eps) + (1 - target) * log(1 - input + eps)))

        where ``w`` is ``self.weights`` (omitted when it is ``None``) and
        ``eps`` is ``self.eps``. When ``self.sizeAverage`` is true, the sum is
        divided by ``input.nelement()``.

        If ``target`` is not 1-D, ``self.weights`` is viewed as a single row
        of length ``target.size(1)`` and expanded to ``target``'s shape, so
        each column gets one weight. This assumes a 2-D target (TODO: confirm
        for other ranks).

        Side effects: lazily allocates ``self.buffer`` via ``input.new()``,
        resizes it to ``input``'s shape and leaves it holding the weighted
        ``log(1 - input + eps)`` term; stores the loss in ``self.output``.

        Args:
            input: predicted values (presumably probabilities in [0, 1];
                not checked here).
            target: tensor with the same number of elements as ``input``.

        Returns:
            ``self.output``, the negated (optionally averaged) sum above.

        Raises:
            RuntimeError: if ``input`` and ``target`` differ in element count.
        """
        if target.nelement() != input.nelement():
            raise RuntimeError("input and target size mismatch")

        if self.buffer is None:
            self.buffer = input.new()
        scratch = self.buffer
        w = self.weights
        scratch.resize_as_(input)

        if w is not None and target.dim() != 1:
            # One weight per column, broadcast down every row of the target.
            w = w.view(1, target.size(1)).expand_as(target)

        def weighted_log_(shifted):
            # In place: shifted <- log(shifted), then scaled by w when present.
            shifted.log_()
            if w is not None:
                shifted.mul_(w)

        flat_target = target.contiguous().view(-1)

        # Positive-class term: sum(target * w * log(input + eps)).
        weighted_log_(torch.add(input, self.eps, out=scratch))
        # Flatten on demand instead of keeping a 1-D alias: scratch is refilled
        # and summed with its original shape below.
        positive = torch.dot(flat_target, scratch.contiguous().view(-1))

        # Negative-class term: sum((1 - target) * w * log(1 - input + eps)),
        # expanded as sum(scratch) - dot(target, scratch).
        weighted_log_(torch.mul(input, -1, out=scratch).add_(1 + self.eps))
        total = positive + torch.sum(scratch)
        total = total - torch.dot(flat_target, scratch.contiguous().view(-1))

        if self.sizeAverage:
            total = total / input.nelement()

        self.output = -total
        return self.output
评论列表
文章目录


问题


面经


文章

微信公众号

扫码关注公众号