test_vbn.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def test_statistics(self):
        """Check that `_statistics` gives the same result as `nn.moments`."""
        tf.set_random_seed(1234)

        tensors = tf.random_normal([4, 5, 7, 3])
        for axes in [(3), (0, 2), (1, 2, 3)]:
            vb_mean, mean_sq = virtual_batchnorm._statistics(tensors, axes)
            mom_mean, mom_var = tf.nn.moments(tensors, axes)
            vb_var = mean_sq - tf.square(vb_mean)

            with self.test_session(use_gpu=True) as sess:
                vb_mean_np, vb_var_np, mom_mean_np, mom_var_np = sess.run([
                    vb_mean, vb_var, mom_mean, mom_var])

            self.assertAllClose(mom_mean_np, vb_mean_np)
            self.assertAllClose(mom_var_np, vb_var_np)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号