def test_discriminator_loss_with_placeholder_for_logits(self):
    """Check the discriminator loss when both logits come from placeholders.

    Both logits tensors are float32 placeholders declared with shape
    ``(None, 4)``, so their leading dimension is not known while the loss
    graph is being built. The loss is then evaluated by feeding the
    fixture arrays ``self._discriminator_real_outputs_np`` and
    ``self._discriminator_gen_outputs_np`` (each wrapped in a one-element
    list) and compared against ``self._expected_d_loss`` to 5 decimal
    places.
    """
    real_logits = tf.placeholder(tf.float32, shape=(None, 4))
    gen_logits = tf.placeholder(tf.float32, shape=(None, 4))

    # All-ones weights. Each weight tensor is derived from the logits it is
    # paired with, so its run-time shape follows that tensor's feed rather
    # than the other placeholder's.
    real_weights = tf.ones_like(real_logits, dtype=tf.float32)
    generated_weights = tf.ones_like(gen_logits, dtype=tf.float32)

    loss_op = self._d_loss_fn(
        real_logits,
        gen_logits,
        real_weights=real_weights,
        generated_weights=generated_weights)

    with self.test_session() as sess:
        # Keep the fetched value under its own name instead of rebinding
        # the graph tensor, so the tensor/value distinction stays visible.
        loss_value = sess.run(
            loss_op,
            feed_dict={
                real_logits: [self._discriminator_real_outputs_np],
                gen_logits: [self._discriminator_gen_outputs_np],
            })
        self.assertAlmostEqual(self._expected_d_loss, loss_value, 5)
评论列表
文章目录