test_gan_metrics.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def test_trace_sqrt_product_value(self):
    """Check `trace_sqrt_product` against the reference computation.

    Two seeded random pools are reduced to covariance matrices. The value
    produced by `gan_metrics.trace_sqrt_product` (built through
    `_run_with_mock` and evaluated in a test session) must agree with
    `_expected_trace_sqrt_product` applied to the same matrices.
    """
    np.random.seed(0)

    # Keep num_examples above num_features so that scipy's sqrtm function
    # doesn't return a complex matrix.
    num_examples, num_features = 512, 256
    # Draw order matters: the "real" pool is sampled before the "gen" pool.
    real_pool = np.random.randn(num_examples, num_features).astype(np.float32)
    gen_pool = np.random.randn(num_examples, num_features).astype(np.float32)

    # rowvar=False: each column is treated as a variable, each row as an
    # observation.
    sigma_real = np.cov(real_pool, rowvar=False)
    sigma_gen = np.cov(gen_pool, rowvar=False)

    # tsp == trace_sqrt_product
    tsp_op = _run_with_mock(
        gan_metrics.trace_sqrt_product, sigma_real, sigma_gen)

    with self.test_session() as session:
        tsp_actual = session.run(tsp_op)

    tsp_expected = _expected_trace_sqrt_product(sigma_real, sigma_gen)

    # NOTE(review): 0.01 is passed positionally as the tolerance --
    # presumably `rtol`; confirm against the TestCase base class.
    self.assertAllClose(tsp_actual, tsp_expected, 0.01)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号