utils.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:GAN-Zoo 作者: corenel 项目源码 文件源码
def calc_gradient_penalty(D, real_data, fake_data):
    """Calculatge gradient penalty for WGAN-GP."""
    alpha = torch.rand(params.batch_size, 1)
    alpha = alpha.expand(real_data.size())
    alpha = make_cuda(alpha)

    interpolates = make_variable(alpha * real_data + ((1 - alpha) * fake_data))
    interpolates.requires_grad = True

    disc_interpolates = D(interpolates)

    gradients = grad(outputs=disc_interpolates,
                     inputs=interpolates,
                     grad_outputs=make_cuda(
                         torch.ones(disc_interpolates.size())),
                     create_graph=True,
                     retain_graph=True,
                     only_inputs=True)[0]

    gradient_penalty = params.penalty_lambda * \
        ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    return gradient_penalty
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号