models.py source code

python

Project: CausalGAN  Author: mkocaoglu
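import tensorflow as tf  # TensorFlow 1.x API (tf.random_uniform, tf.gradients)

def Grad_Penalty(real_data,fake_data,Discriminator,config):
    '''
    Implementation of the gradient penalty from "Improved Training of Wasserstein GANs".
    Estimates the critic's gradient at random interpolates between real and fake samples
    and penalizes its deviation from norm 1, instead of explicitly enforcing a Lipschitz constraint.
    '''
    batch_size=config.batch_size
    LAMBDA=config.lambda_W
    n_hidden=config.critic_hidden_size
    # One interpolation weight per sample; shape [batch_size, 1] broadcasts over the
    # feature axis (assumes real_data and fake_data are 2-D: [batch_size, n_features]).
    alpha = tf.random_uniform([batch_size,1],0.,1.)
    # Assumes a fixed batch_size, since alpha is built from config.batch_size.
    interpolates = alpha*real_data + ((1-alpha)*fake_data)
    # Reuse the critic's variables; index [1] selects its logits output.
    disc_interpolates = Discriminator(interpolates,batch_size,n_hidden=n_hidden,config=config, reuse=True)[1]
    # Gradient of the critic logits with respect to each interpolated input.
    gradients = tf.gradients(disc_interpolates,[interpolates])[0]
    # Per-sample L2 norm of the gradient.
    slopes = tf.sqrt(tf.reduce_sum(tf.square(gradients), axis=[1]))
    gradient_penalty = tf.reduce_mean((slopes-1)**2)
    grad_cost = LAMBDA*gradient_penalty
    return grad_cost,slopes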
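To make the computation concrete without a TensorFlow 1.x install, here is a minimal, dependency-free sketch of the same penalty, lambda * mean((||grad D(x_hat)||_2 - 1)^2) with x_hat = alpha*real + (1-alpha)*fake. It is not CausalGAN's code: `make_critic`, `critic`, `critic_input_grad` and `grad_penalty` are hypothetical names, the critic is an assumed one-hidden-layer tanh network whose input gradient is written out by hand in place of `tf.gradients`, and `lam=10.0` is an assumed stand-in for `config.lambda_W` (it matches the lambda = 10 used in the WGAN-GP paper).

```python
import math
import random

# Hypothetical critic for illustration only (not CausalGAN's Discriminator):
# D(x) = sum_j v[j] * tanh(w[j] . x + c[j]), a one-hidden-layer tanh network.
def make_critic(dim, hidden, seed=0):
    rng = random.Random(seed)
    w = [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(hidden)]
    c = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
    v = [rng.gauss(0.0, 1.0) for _ in range(hidden)]
    return w, c, v

def critic(params, x):
    """Scalar critic output (the 'logit') for a single sample x."""
    w, c, v = params
    return sum(vj * math.tanh(sum(wji * xi for wji, xi in zip(wj, x)) + cj)
               for wj, cj, vj in zip(w, c, v))

def critic_input_grad(params, x):
    """Analytic dD/dx = sum_j v[j] * (1 - tanh(h_j)^2) * w[j]."""
    w, c, v = params
    grad = [0.0] * len(x)
    for wj, cj, vj in zip(w, c, v):
        t = math.tanh(sum(wji * xi for wji, xi in zip(wj, x)) + cj)
        scale = vj * (1.0 - t * t)
        for i, wji in enumerate(wj):
            grad[i] += scale * wji
    return grad

def grad_penalty(real, fake, params, lam=10.0, seed=0):
    """Return (lam * mean((slope - 1)^2), slopes) over the batch, where each
    slope is the L2 norm of the critic's gradient at a random interpolate."""
    rng = random.Random(seed)
    slopes = []
    for xr, xf in zip(real, fake):
        alpha = rng.random()  # one weight per sample, like alpha of shape [batch_size, 1]
        x_hat = [alpha * a + (1.0 - alpha) * b for a, b in zip(xr, xf)]
        g = critic_input_grad(params, x_hat)
        slopes.append(math.sqrt(sum(gi * gi for gi in g)))  # per-sample L2 norm
    penalty = sum((s - 1.0) ** 2 for s in slopes) / len(slopes)
    return lam * penalty, slopes

# Usage: a batch of two 3-dimensional real and fake samples.
params = make_critic(dim=3, hidden=4)
real = [[1.0, 0.0, 2.0], [0.5, -1.0, 0.0]]
fake = [[0.0, 1.0, -1.0], [2.0, 2.0, 2.0]]
cost, slopes = grad_penalty(real, fake, params, lam=10.0)
print(cost, slopes)
```

The sketch loops per sample, while the TensorFlow version calls `tf.gradients` once on the whole batch. The two agree because `tf.gradients` differentiates the sum of the batch logits, and each logit depends only on its own input row, so row i of the result is exactly sample i's gradient.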