def test_frechet_classifier_distance_value(self):
    """Test that `frechet_classifier_distance` gives the correct value.

    Draws two seeded 512x256 float32 pools of standard-normal samples, builds
    the metric op through `_run_with_mock` with an identity `classifier_fn`,
    evaluates it in a test session, and compares the result against the
    `_expected_fid` reference computed on the same two pools.
    """
    # Use a local, seeded generator rather than `np.random.seed` so this test
    # does not reset the process-wide NumPy RNG for whatever runs after it.
    # RandomState(0) produces the same stream `np.random.seed(0)` would.
    rng = np.random.RandomState(0)
    # Make num_examples > num_features to ensure scipy's sqrtm function
    # doesn't return a complex matrix.
    test_pool_real_a = rng.randn(512, 256).astype(np.float32)
    test_pool_gen_a = rng.randn(512, 256).astype(np.float32)
    # Identity classifier: presumably the metric applies `classifier_fn` to
    # both pools, so the distance is computed on the raw samples — confirm
    # against `gan_metrics.frechet_classifier_distance`.
    fid_op = _run_with_mock(gan_metrics.frechet_classifier_distance,
                            test_pool_real_a, test_pool_gen_a,
                            classifier_fn=lambda x: x)
    with self.test_session() as sess:
        actual_fid = sess.run(fid_op)
    expected_fid = _expected_fid(test_pool_real_a, test_pool_gen_a)
    # Passed by keyword: the third positional argument of `assertAllClose`
    # is `rtol` (relative tolerance), not an absolute tolerance.
    self.assertAllClose(expected_fid, actual_fid, rtol=0.0001)
评论列表
文章目录