test_gan_losses.py source code

python

Project: tefla    Author: openAGI
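import tensorflow as tf  # TF 1.x graph-mode API (tf.placeholder, sessions)


# Method of the GAN-loss test class; self._d_loss_fn and the *_np fixtures are defined there (not shown).
def test_discriminator_loss_with_placeholder_for_logits(self):
    # Batch dimension is None, so the loss graph must accept any batch size.
    logits = tf.placeholder(tf.float32, shape=(None, 4))   # real discriminator outputs
    logits2 = tf.placeholder(tf.float32, shape=(None, 4))  # generated discriminator outputs
    # All-ones weights should leave the loss equal to its unweighted value.
    real_weights = tf.ones_like(logits, dtype=tf.float32)
    generated_weights = tf.ones_like(logits2, dtype=tf.float32)

    loss = self._d_loss_fn(
        logits, logits2, real_weights=real_weights,
        generated_weights=generated_weights)

    with self.test_session() as sess:
        # Wrapping each fixture in a list feeds a single-row batch of shape (1, 4).
        loss_value = sess.run(loss,
                              feed_dict={
                                  logits: [self._discriminator_real_outputs_np],
                                  logits2: [self._discriminator_gen_outputs_np],
                              })
        # Compare to 5 decimal places.
        self.assertAlmostEqual(self._expected_d_loss, loss_value, 5)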
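The snippet does not show which discriminator loss `self._d_loss_fn` is, or the values of its fixtures. As a rough illustration of what such a test exercises (real logits, generated logits, per-element weights, and a batch dimension of any size), here is a pure-Python sketch of two common choices: a Wasserstein critic loss and a minimax (sigmoid cross-entropy) loss. The function names, the `Rows` alias, and the weighted-mean reduction `sum(w * v) / sum(w)` are hypothetical assumptions for this sketch, not tefla's API, and the input values are illustrative rather than the test's fixtures.

```python
import math
from typing import List, Optional, Sequence

# A batch of discriminator outputs shaped (batch, 4); batch size is not fixed,
# mirroring the placeholders' shape=(None, 4). Hypothetical alias.
Rows = Sequence[Sequence[float]]


def _flatten(rows: Rows) -> List[float]:
    return [x for row in rows for x in row]


def _weights_or_ones(weights: Optional[Rows], n: int) -> List[float]:
    # Missing weights default to 1.0 for every element.
    return _flatten(weights) if weights is not None else [1.0] * n


def _weighted_mean(values: Sequence[float], weights: Sequence[float]) -> float:
    # Assumed reduction: sum(w * v) / sum(w). With all-ones weights this is the
    # plain mean; a real library may normalise differently.
    total = sum(weights)
    if total == 0:
        return 0.0
    return sum(w * v for v, w in zip(values, weights)) / total


def _softplus(x: float) -> float:
    # Numerically stable log(1 + exp(x)).
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def wasserstein_d_loss(real: Rows, gen: Rows,
                       real_weights: Optional[Rows] = None,
                       generated_weights: Optional[Rows] = None) -> float:
    """Hypothetical critic loss: E[D(G(z))] - E[D(x)]."""
    r, g = _flatten(real), _flatten(gen)
    rw = _weights_or_ones(real_weights, len(r))
    gw = _weights_or_ones(generated_weights, len(g))
    return _weighted_mean(g, gw) - _weighted_mean(r, rw)


def minimax_d_loss(real: Rows, gen: Rows,
                   real_weights: Optional[Rows] = None,
                   generated_weights: Optional[Rows] = None) -> float:
    """Hypothetical sigmoid cross-entropy loss: real labelled 1, generated labelled 0."""
    r, g = _flatten(real), _flatten(gen)
    rw = _weights_or_ones(real_weights, len(r))
    gw = _weights_or_ones(generated_weights, len(g))
    # -log(sigmoid(x)) == softplus(-x) and -log(1 - sigmoid(x)) == softplus(x).
    on_real = _weighted_mean([_softplus(-x) for x in r], rw)
    on_gen = _weighted_mean([_softplus(x) for x in g], gw)
    return on_real + on_gen


if __name__ == "__main__":
    real = [[1.0, 2.0, 3.0, 4.0]]  # illustrative values, not the test's fixtures
    gen = [[0.5, 0.5, 0.5, 0.5]]
    ones = [[1.0] * 4]
    print(wasserstein_d_loss(real, gen, real_weights=ones, generated_weights=ones))  # -2.0
    print(wasserstein_d_loss(real * 2, gen * 2))  # -2.0: batch size does not matter
```

The two printed values show the two properties the TensorFlow test relies on: all-ones weights give the same result as no weights, and the loss does not depend on how many rows are fed through the `None` batch dimension.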