def test_statistics(self):
    """Check that `_statistics` gives the same result as `nn.moments`.

    The two values returned by `virtual_batchnorm._statistics` are used here
    as the mean and the mean of squares over `axes`. The variance is derived
    as E[x^2] - E[x]^2 and compared, together with the mean, against
    `tf.nn.moments` on the same input.
    """
    tf.set_random_seed(1234)
    tensors = tf.random_normal([4, 5, 7, 3])
    # Every entry is a tuple of axes. `(3)` would be the plain int 3 rather
    # than a one-element tuple, so the trailing comma in `(3,)` is required.
    for axes in [(3,), (0, 2), (1, 2, 3)]:
        vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes)
        mom_mean, mom_var = tf.nn.moments(tensors, axes)
        # Var[x] = E[x^2] - (E[x])^2.
        vb_var = mean_sq - tf.square(vb_mean)
        with self.test_session(use_gpu=True) as sess:
            # Fetch all four tensors in a single run so that both statistics
            # paths see the same random sample of `tensors`.
            vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run(
                [vb_mean, vb_var, mom_mean, mom_var])
            self.assertAllClose(mom_mean_np, vb_mean_np)
            self.assertAllClose(mom_var_np, vb_var_np)
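# 评论列表 ("Comment list" -- scraped web-page label, not part of the test)
# 文章目录 ("Table of contents" -- scraped web-page label, not part of the test)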