HingeLoss.py source code

python

Project: pytorch-geometric-gan, author: lim0606
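import torch
import torch.nn as nn


class HingeLoss(nn.Module):
    # NOTE: the original __init__ is not part of this snippet. The constructor
    # below is an assumed reconstruction of the three attributes forward()
    # reads (margin, size_average, sign); the default values are assumptions.
    def __init__(self, margin=1.0, size_average=True, sign=1.0):
        super().__init__()
        self.margin = margin
        self.size_average = size_average
        self.sign = sign

    def forward(self, input, target):
        # flatten the scores into a 1-D tensor
        input = input.view(-1)

        # target must already be 1-D with exactly the same shape
        assert input.dim() == target.dim()
        assert input.size() == target.size()

        # per-element margin term: margin - y * f(x)
        output = self.margin - torch.mul(target, input)

        # hinge: keep only positive entries, i.e. max(0, margin - y * f(x)).
        # A boolean mask cast to the output dtype stays on the same device,
        # replacing the deprecated torch.cuda.FloatTensor / Variable branch.
        mask = (output > 0).to(output.dtype)
        output = torch.mul(output, mask)

        # size average: divide by the number of elements
        if self.size_average:
            output = torch.mul(output, 1.0 / input.nelement())

        # sum over all elements
        output = output.sum()

        # apply sign
        output = torch.mul(output, self.sign)
        return output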
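As a cross-check of what `forward` computes, here is a minimal pure-Python sketch of the same quantity, sign * s * sum_i max(0, margin - y_i * x_i), where s = 1/N when `size_average` is set and 1 otherwise. The function name `hinge_loss_reference` and its default values (margin 1.0, size_average True, sign 1.0) are assumptions for illustration and are not part of the pytorch-geometric-gan project; targets are assumed to be +1/-1 labels.

```python
from typing import Sequence


def hinge_loss_reference(
    scores: Sequence[float],
    targets: Sequence[float],
    margin: float = 1.0,
    size_average: bool = True,
    sign: float = 1.0,
) -> float:
    """Hypothetical plain-Python mirror of HingeLoss.forward.

    Returns sign * scale * sum_i max(0, margin - targets[i] * scores[i]),
    where scale is 1/N if size_average is True and 1 otherwise.
    """
    if len(scores) != len(targets):
        raise ValueError("scores and targets must have the same length")
    # element-wise hinge: only positive margin violations contribute
    terms = [max(0.0, margin - y * s) for s, y in zip(scores, targets)]
    total = sum(terms)
    if size_average:
        total /= len(scores)
    return sign * total


# Made-up scores for three samples labelled +1, +1, -1.
# Hinge terms are 0, 0.5 and 1.3, so the averaged loss is (0 + 0.5 + 1.3) / 3, about 0.6.
print(hinge_loss_reference([2.0, 0.5, 0.3], [1.0, 1.0, -1.0]))
```

Because it needs no PyTorch, a helper like this can act as an oracle in a unit test that compares it against the tensor version on small inputs.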