def compute_losses(self, images, wrong_images, fake_images, embeddings):
    """Build the discriminator and generator losses for one batch.

    The discriminator (``self.model.get_discriminator``) is evaluated on
    three image sets, each paired with the same ``embeddings``: real
    images (target 1), "wrong" images (target 0) and generated images
    (target 0).

    Args:
        images: Real images; the discriminator is trained to output 1.
        wrong_images: Images the discriminator is trained to reject
            (target 0) -- presumably real images that do not match
            ``embeddings``; confirm against the caller.
        fake_images: Generator outputs; target 0 for the discriminator
            loss, target 1 for the generator loss.
        embeddings: Conditioning input shared by all three discriminator
            calls.

    Returns:
        ``(discriminator_loss, generator_loss)``: each built from the mean
        sigmoid cross-entropy over the corresponding logits.

    Side effects:
        Appends ``("d_loss_wrong", ...)`` (only when ``cfg.TRAIN.B_WRONG``
        is truthy), then ``("d_loss_real", ...)`` and
        ``("d_loss_fake", ...)`` to ``self.log_vars``.
    """
    def _mean_sigmoid_ce(logits, target_like):
        # Mean sigmoid cross-entropy of ``logits`` against a constant target
        # of the same shape (``tf.ones_like`` or ``tf.zeros_like``).
        # Keyword arguments are required: since TF 1.0 this op rejects
        # positional calls (its first parameter is a ``_sentinel``).
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(
                logits=logits, labels=target_like(logits)))

    real_logit = self.model.get_discriminator(images, embeddings)
    # NOTE(review): wrong_logit is built even when cfg.TRAIN.B_WRONG is
    # falsy, in which case it does not contribute to either loss.
    wrong_logit = self.model.get_discriminator(wrong_images, embeddings)
    fake_logit = self.model.get_discriminator(fake_images, embeddings)

    real_d_loss = _mean_sigmoid_ce(real_logit, tf.ones_like)
    wrong_d_loss = _mean_sigmoid_ce(wrong_logit, tf.zeros_like)
    fake_d_loss = _mean_sigmoid_ce(fake_logit, tf.zeros_like)

    if cfg.TRAIN.B_WRONG:
        # The "reject" term is split evenly between wrong and fake images.
        discriminator_loss = real_d_loss + (wrong_d_loss + fake_d_loss) / 2.
        self.log_vars.append(("d_loss_wrong", wrong_d_loss))
    else:
        discriminator_loss = real_d_loss + fake_d_loss
    self.log_vars.append(("d_loss_real", real_d_loss))
    self.log_vars.append(("d_loss_fake", fake_d_loss))

    # Generator is trained to make the discriminator output 1 on fakes
    # (non-saturating formulation).
    generator_loss = _mean_sigmoid_ce(fake_logit, tf.ones_like)
    return discriminator_loss, generator_loss
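# --- Residue from the web page this snippet was scraped from (not code) ---
# trainer.py 文件源码 (trainer.py source code)
# python
# 阅读 29 (reads: 29)
# 收藏 0 (favorites: 0)
# 点赞 0 (likes: 0)
# 评论 0 (comments: 0)
# 评论列表 (comment list)
# 文章目录 (table of contents)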