test_gan_metrics.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def test_frechet_classifier_distance_value(self):
        """Check `frechet_classifier_distance` against a reference FID value.

        Both activation pools come from a seeded standard-normal draw and are
        fed through an identity classifier. The value produced by the op must
        therefore agree with `_expected_fid` evaluated on the raw arrays.
        """
        np.random.seed(0)

        # More examples (512) than features (256) so that scipy's sqrtm does
        # not hand back a complex-valued matrix.
        num_examples, num_features = 512, 256
        real_pool = np.float32(np.random.randn(num_examples, num_features))
        gen_pool = np.float32(np.random.randn(num_examples, num_features))

        def identity(activations):
            # Pass-through classifier: the pools are used as activations.
            return activations

        fid_tensor = _run_with_mock(
            gan_metrics.frechet_classifier_distance,
            real_pool,
            gen_pool,
            classifier_fn=identity,
        )

        with self.test_session() as session:
            computed = session.run(fid_tensor)

        reference = _expected_fid(real_pool, gen_pool)

        # Third positional argument is the relative tolerance.
        self.assertAllClose(reference, computed, 0.0001)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号