loss.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:pytorch-dist 作者: apaszke 项目源码 文件源码
def forward(self, input, target):
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        buffer.fill_(self.margin).add_(-1, input)
        buffer.cmax_(0)
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if self.size_average:
            output = output / input.nelement()

        self.save_for_backward(input, target)
        return input.new((output,))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号