loss.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:pytorch 作者: ezyang 项目源码 文件源码
def backward(ctx, grad_output):
    """Backward pass of a cosine-embedding-style loss.

    Args:
        ctx: context populated by the matching ``forward`` (not visible here).
            Reads ``saved_tensors`` as ``(v1, v2, y)`` plus the cached
            intermediates ``w1``, ``w22``, ``w32``, ``w``, ``_outputs`` and
            the ``size_average`` flag. Judging from the arithmetic below,
            ``w1`` appears to be the per-row dot product of ``v1`` and ``v2``,
            ``w22`` / ``w32`` the reciprocal squared row norms of ``v1`` /
            ``v2``, ``w`` equal to ``1 / (|v1| * |v2|)`` and ``_outputs`` the
            per-sample losses -- NOTE(review): confirm against ``forward``.
        grad_output: upstream gradient; only its first element is used, as a
            scalar multiplier.

    Returns:
        ``(grad_v1, grad_v2, None, None, None)`` -- gradients for the first
        two forward inputs and no gradient for the remaining three.
    """
    v1, v2, y = ctx.saved_tensors

    # Per-row gradient of the cosine w.r.t. v1: w * (v2 - w1 * w22 * v1).
    # sub_(a * b) is the non-deprecated equivalent of addcmul_(-1, a, b).
    gw1 = grad_output.new().resize_as_(v1).copy_(v2)
    gw1.sub_(torch.mul(ctx.w1, ctx.w22).expand_as(v1) * v1)
    gw1.mul_(ctx.w.expand_as(v1))

    # Symmetric gradient w.r.t. v2: w * (v1 - w1 * w32 * v2).
    gw2 = grad_output.new().resize_as_(v1).copy_(v1)
    gw2.sub_(torch.mul(ctx.w1, ctx.w32).expand_as(v1) * v2)
    gw2.mul_(ctx.w.expand_as(v1))

    # Each mask comes from a fresh comparison result. Writing a comparison
    # into an expanded (stride-0) view via out= is rejected by current
    # PyTorch as an internally overlapping output.
    # Rows whose per-sample output is <= 0 get no gradient.
    inactive = ctx._outputs.le(0).view(-1, 1).expand_as(gw1)
    gw1.masked_fill_(inactive, 0)
    gw2.masked_fill_(inactive, 0)

    # Rows with y == 1 get the negated cosine gradient (consistent with a
    # per-sample loss of 1 - cos for those rows).
    positive = y.eq(1).view(-1, 1).expand_as(gw1)
    gw1[positive] = -gw1[positive]
    gw2[positive] = -gw2[positive]

    if ctx.size_average:
        gw1.div_(y.size(0))
        gw2.div_(y.size(0))

    # Chain rule with the (scalar) upstream gradient; skip the no-op case.
    grad_output_val = grad_output[0]
    if grad_output_val != 1:
        gw1.mul_(grad_output_val)
        gw2.mul_(grad_output_val)

    return gw1, gw2, None, None, None
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号