losses.py source code

Project: TensorflowFramework  Author: vahidk

```python
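import tensorflow as tf  # Uses the TensorFlow 1.x graph-mode API.


def wgan_loss(x, gz, discriminator, beta=10.0):
  """Improved Wasserstein GAN (WGAN-GP) loss.

  Args:
    x: Batch of real samples.
    gz: Batch of generated samples.
    discriminator: Discriminator (critic) function.
    beta: Gradient-penalty coefficient.
  Returns:
    d_loss: Discriminator loss.
    g_loss: Generator loss.
  """
  # The first call creates the discriminator variables.
  dx = discriminator(x)
  # Later calls reuse the same variables.
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    dgz = discriminator(gz)
  # One interpolation coefficient per example, shaped [batch, 1, ..., 1]
  # so that it broadcasts over every non-batch dimension of x.
  batch_size = tf.shape(x)[0]
  alpha_shape = tf.concat(
      [[batch_size], tf.ones([tf.rank(x) - 1], dtype=tf.int32)], axis=0)
  alpha = tf.random_uniform(alpha_shape, dtype=x.dtype)
  xhat = x * alpha + gz * (1 - alpha)
  with tf.variable_scope(tf.get_variable_scope(), reuse=True):
    dxhat = discriminator(xhat)
  # Per-example L2 norm of the discriminator gradient at xhat
  # (not a single norm over the whole batch).
  grads = tf.gradients(dxhat, xhat)[0]
  gnorm = tf.norm(tf.reshape(grads, [batch_size, -1]), axis=1)
  # Discriminator: E[D(G(z))] - E[D(x)] + beta * E[(||grad|| - 1)^2].
  d_loss = (tf.reduce_mean(dgz) - tf.reduce_mean(dx) +
            beta * tf.reduce_mean(tf.square(gnorm - 1)))
  # Generator: maximize the discriminator score of generated samples.
  g_loss = -tf.reduce_mean(dgz)
  return d_loss, g_loss
```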
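The objective behind `wgan_loss` is d_loss = mean D(G(z)) − mean D(x) + beta · mean (‖∇D(x̂)‖ − 1)², with x̂ = alpha · x + (1 − alpha) · G(z) and one alpha per example, and g_loss = −mean D(G(z)). A minimal, framework-free sketch of that objective follows, assuming a critic that maps a list of floats to a float. The helper names `critic_grad_norm` and `wgan_gp_losses` are hypothetical, and the gradient is approximated by central finite differences rather than the automatic differentiation (`tf.gradients`) that the TensorFlow version relies on.

```python
import math
import random


def critic_grad_norm(critic, point, eps=1e-5):
  """L2 norm of d critic / d point, estimated by central finite differences."""
  sq = 0.0
  for i in range(len(point)):
    up = list(point)
    up[i] += eps
    down = list(point)
    down[i] -= eps
    g = (critic(up) - critic(down)) / (2 * eps)
    sq += g * g
  return math.sqrt(sq)


def wgan_gp_losses(x, gz, critic, beta=10.0, rng=random):
  """Hypothetical pure-Python WGAN-GP losses for lists of real/fake samples."""
  n = len(x)
  dx = [critic(p) for p in x]
  dgz = [critic(p) for p in gz]
  penalty = 0.0
  for real, fake in zip(x, gz):
    alpha = rng.random()  # one interpolation coefficient per example
    xhat = [alpha * r + (1 - alpha) * f for r, f in zip(real, fake)]
    # Penalize each example's gradient norm for deviating from 1.
    penalty += (critic_grad_norm(critic, xhat) - 1.0) ** 2
  d_loss = sum(dgz) / n - sum(dx) / n + beta * penalty / n
  g_loss = -sum(dgz) / n
  return d_loss, g_loss
```

With a linear critic D(p) = 3·p[0] + 4·p[1], the gradient norm is 5 at every point, so each example adds (5 − 1)² = 16 to the penalty whatever alpha is drawn. For x = [[1, 0], [0, 1]] and gz = [[0, 0], [0, 0]], this gives d_loss ≈ 0 − 3.5 + 10 · 16 = 156.5 and g_loss = 0. A critic whose gradient norm is exactly 1 incurs no penalty, which is the 1-Lipschitz behavior the penalty is meant to encourage.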