def test_trace_sqrt_product_value(self):
    """Check `trace_sqrt_product` against the reference computation.

    Two covariance matrices are built from seeded random samples. The op
    produced by `gan_metrics.trace_sqrt_product` (invoked through
    `_run_with_mock`) is evaluated in a test session and compared with the
    value returned by `_expected_trace_sqrt_product`, using a relative
    tolerance of 0.01.
    """
    # Fixed seed so both sample pools are reproducible across runs.
    np.random.seed(0)

    # Use more examples than features to ensure scipy's sqrtm function
    # doesn't return a complex matrix.
    num_examples, num_features = 512, 256
    real_samples = np.float32(np.random.randn(num_examples, num_features))
    gen_samples = np.float32(np.random.randn(num_examples, num_features))

    # rowvar=False: each column is a variable, each row an observation.
    sigma_real = np.cov(real_samples, rowvar=False)
    sigma_gen = np.cov(gen_samples, rowvar=False)

    tsp_op = _run_with_mock(
        gan_metrics.trace_sqrt_product, sigma_real, sigma_gen)

    with self.test_session() as sess:
        # tsp: trace_sqrt_product
        tsp_actual = sess.run(tsp_op)
        tsp_expected = _expected_trace_sqrt_product(sigma_real, sigma_gen)
        self.assertAllClose(tsp_actual, tsp_expected, rtol=0.01)
评论列表
文章目录