BCECriterion.py 文件源码

python
阅读 46 收藏 0 点赞 0 评论 0

项目:pytorch 作者: tylergenter 项目源码 文件源码
def updateOutput(self, input, target):
        """Compute the (optionally weighted) binary cross-entropy loss.

        Evaluates ``-log(input) * target - log(1 - input) * (1 - target)``
        summed over all elements, with ``self.eps`` added inside each
        logarithm. When ``self.sizeAverage`` is truthy the sum is divided by
        the number of elements in ``input``.

        Args:
            input: tensor of predictions. Presumably probabilities in
                [0, 1] -- this is not checked here.
            target: tensor of targets with the same number of elements as
                ``input``.

        Returns:
            The scalar loss, which is also stored in ``self.output``.

        Raises:
            RuntimeError: if ``input`` and ``target`` differ in element count.

        Side effects:
            Lazily allocates ``self.buffer`` and overwrites its contents.
        """
        if input.nelement() != target.nelement():
            raise RuntimeError("input and target size mismatch")

        if self.buffer is None:
            self.buffer = input.new()

        buffer = self.buffer
        weights = self.weights

        buffer.resize_as_(input)

        if weights is not None and target.dim() != 1:
            # Treat weights as one value per column and repeat them over
            # the leading dimension so they line up with every element.
            weights = self.weights.view(1, target.size(1)).expand_as(target)

        # torch.dot only accepts 1-D operands, so batched tensors must be
        # flattened first. The element-count check above guarantees that the
        # flattened target and buffer have the same length.
        flat_target = target.contiguous().view(-1)

        # log(input) * target
        torch.add(input, self.eps, out=buffer).log_()
        if weights is not None:
            buffer.mul_(weights)

        output = torch.dot(flat_target, buffer.contiguous().view(-1))

        # log(1 - input) * (1 - target), expanded as
        # sum(log(1 - input)) - dot(target, log(1 - input))
        torch.mul(input, -1, out=buffer).add_(1 + self.eps).log_()
        if weights is not None:
            buffer.mul_(weights)

        output = output + torch.sum(buffer)
        output = output - torch.dot(flat_target, buffer.contiguous().view(-1))

        if self.sizeAverage:
            output = output / input.nelement()

        self.output = - output

        return self.output
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号