loss.py 文件源码

python
阅读 29 收藏 0 点赞 0 评论 0

项目:pytorch 作者: pytorch 项目源码 文件源码
def forward(ctx, input, target, grad_output, margin, size_average):
        ctx.margin = margin
        ctx.size_average = size_average
        ctx.save_for_backward(input, target, grad_output)
        grad_input = input.new().resize_as_(input).copy_(target)
        grad_input[torch.mul(torch.eq(target, -1), torch.gt(input, ctx.margin))] = 0

        if ctx.size_average:
            grad_input.mul_(1. / input.nelement())

        if grad_output[0] != 1:
            grad_input.mul_(grad_output[0])

        return grad_input
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号