loss.py 文件源码

python
阅读 36 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def forward(ctx, input, target, margin, size_average):
        ctx.margin = margin
        ctx.size_average = size_average
        buffer = input.new()
        buffer.resize_as_(input).copy_(input)
        buffer[torch.eq(target, -1.)] = 0
        output = buffer.sum()

        buffer.fill_(ctx.margin).add_(-1, input)
        buffer.clamp_(min=0)
        buffer[torch.eq(target, 1.)] = 0
        output += buffer.sum()

        if ctx.size_average:
            output = output / input.nelement()

        ctx.save_for_backward(input, target)
        return input.new((output,))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号