def wgan_loss(x, gz, discriminator, beta=10.0):
    """Improved Wasserstein GAN (WGAN-GP) loss.

    The critic loss is
    ``mean(D(gz)) - mean(D(x)) + beta * mean((||grad D(xhat)||_2 - 1)^2)``.
    Here ``xhat`` is a random per-sample interpolation between real and
    generated samples, and the gradient norm is taken separately for each
    sample.

    The discriminator is called three times in the current variable scope.
    The second and third calls run with ``reuse=True`` so they share the
    variables created by the first call.

    Args:
        x: Batch of real samples; the first axis is the batch axis.
        gz: Batch of generated samples, same shape and dtype as ``x``.
        discriminator: Callable mapping a batch of samples to critic scores.
        beta: Regularization factor (weight of the gradient-penalty term).

    Returns:
        d_loss: Scalar discriminator (critic) loss.
        g_loss: Scalar generator loss.
    """
    dx = discriminator(x)
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        dgz = discriminator(gz)

    # One interpolation coefficient per sample, shaped [batch, 1, ..., 1] so
    # it broadcasts over every non-batch axis of x. A flat [batch] vector
    # would broadcast against the LAST axis of x instead of the batch axis.
    x_shape = tf.shape(x)
    batch_size = x_shape[0]
    alpha_shape = tf.concat([x_shape[:1], tf.ones_like(x_shape[1:])], axis=0)
    alpha = tf.random_uniform(alpha_shape, dtype=x.dtype)
    xhat = alpha * x + (1.0 - alpha) * gz
    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
        dxhat = discriminator(xhat)

    # The penalty constrains the gradient norm of EACH sample, so flatten the
    # non-batch axes and take the L2 norm per row rather than one norm over
    # the whole batch tensor. The tiny epsilon keeps the sqrt's gradient
    # finite if a sample's gradient is exactly zero.
    grad = tf.gradients(dxhat, xhat)[0]
    grad = tf.reshape(grad, [batch_size, -1])
    gnorm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1) + 1e-12)
    gradient_penalty = tf.reduce_mean(tf.square(gnorm - 1.0))

    # Reduce each term on its own so differently shaped tensors (e.g. scores
    # of shape [batch, 1] vs norms of shape [batch]) cannot silently
    # broadcast into a [batch, batch] matrix before averaging.
    d_loss = tf.reduce_mean(dgz) - tf.reduce_mean(dx) + beta * gradient_penalty
    g_loss = -tf.reduce_mean(dgz)
    return d_loss, g_loss
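# Trailing labels scraped from the source web page ("Comment list",
# "Table of contents"); not part of the code.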