def test_same_as_batchnorm(self):
    """Compare batch norm over a full set with leave-one-out VBN.

    Batch normalization (in training mode) is applied to the stack of all
    examples. Then, for each example in turn, a `VBN` is built from the
    stack of the remaining examples and applied to the held-out example.
    The result is asserted to be close to the held-out example's slice of
    the batch-normalized output.
    """
    tf.set_random_seed(1234)
    num_examples = 4
    examples = []
    for _ in range(num_examples):
        examples.append(tf.random_normal([5, 7, 3]))

    # Baseline: the open-source batch normalization over the full stack.
    full_batch = tf.stack(examples)
    batch_normalized = tf.layers.batch_normalization(full_batch, training=True)

    for held_out_idx, held_out in enumerate(examples):
        # The VBN is constructed from every example except the held-out one.
        others = examples[:held_out_idx] + examples[held_out_idx + 1:]
        vbn = virtual_batchnorm.VBN(tf.stack(others))
        # Add a leading axis for the VBN call, then remove it again.
        held_out_batched = tf.expand_dims(held_out, [0])
        vb_normed = tf.squeeze(vbn(held_out_batched), [0])

        with self.test_session(use_gpu=True) as sess:
            tf.global_variables_initializer().run()
            # NOTE(review): presumably both tensors are fetched in one run
            # call so they are computed from the same random draw - confirm.
            bn_np, vb_np = sess.run([batch_normalized, vb_normed])
        self.assertAllClose(bn_np[held_out_idx, ...], vb_np)
# Web-page residue (not code): "Comment list" / "Table of contents".