dev_train.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:PaintsPytorch 作者: orashi 项目源码 文件源码
def calc_gradient_penalty(netD, real_data, fake_data):
    # print "real_data: ", real_data.size(), fake_data.size()
    alpha = torch.rand(opt.batchSize, 1, 1, 1)
    # alpha = alpha.expand(opt.batchSize, real_data.nelement() / opt.batchSize).contiguous().view(opt.batchSize, 3, 64,
    #                                                                                             64)
    alpha = alpha.cuda() if opt.cuda else alpha

    interpolates = alpha * real_data + ((1 - alpha) * fake_data)

    if opt.cuda:
        interpolates = interpolates.cuda()
    interpolates = Variable(interpolates, requires_grad=True)

    disc_interpolates = netD(interpolates)

    gradients = grad(outputs=disc_interpolates, inputs=interpolates,
                     grad_outputs=torch.ones(disc_interpolates.size()).cuda() if opt.cuda else torch.ones(
                         disc_interpolates.size()),
                     create_graph=True, retain_graph=True, only_inputs=True)[0]

    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * opt.gpW
    return gradient_penalty
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号