test_vbn.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def test_same_as_batchnorm(self):
        """Compare training-mode batch norm against VBN, one example at a time.

        Batch normalization is applied to the full stack of examples. Then,
        for each index `i`, a VBN is built from every example except `i` and
        applied to example `i` alone; the result must be close to slice `i`
        of the batch-normalized stack.
        """
        tf.set_random_seed(1234)

        num_examples = 4
        examples = []
        for _ in range(num_examples):
            examples.append(tf.random_normal([5, 7, 3]))

        # Reference output: standard batch normalization over all examples.
        full_batch = tf.stack(examples)
        batch_normalized = tf.layers.batch_normalization(
            full_batch, training=True)

        for i, held_out in enumerate(examples):
            # The VBN is constructed from every example other than the i-th.
            reference_batch = tf.stack(examples[:i] + examples[i + 1:])
            vbn = virtual_batchnorm.VBN(reference_batch)

            # Add a leading axis of size 1 before the call, drop it afterwards.
            held_out_batch = tf.expand_dims(held_out, [0])
            vb_normed = tf.squeeze(vbn(held_out_batch), [0])

            with self.test_session(use_gpu=True) as sess:
                tf.global_variables_initializer().run()
                # Fetch both tensors in one run so they share the same
                # sampled random inputs.
                bn_np, vb_np = sess.run([batch_normalized, vb_normed])
            self.assertAllClose(bn_np[i, ...], vb_np)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号