test_utils.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:zhusuan 作者: thu-ml 项目源码 文件源码
def test_explicit_broadcast(self):
        with self.test_session(use_gpu=True):
            def _test_func(a_shape, b_shape, target_shape):
                a = tf.ones(a_shape)
                b = tf.ones(b_shape)
                a, b = explicit_broadcast(a, b, 'a', 'b')
                self.assertEqual(a.eval().shape, b.eval().shape)
                self.assertEqual(a.eval().shape, target_shape)

            _test_func((5, 4), (1,), (5, 4))
            _test_func((5, 4), (4,), (5, 4))
            _test_func((2, 3, 5), (2, 1, 5), (2, 3, 5))
            _test_func((2, 3, 5), (3, 5), (2, 3, 5))
            _test_func((2, 3, 5), (3, 1), (2, 3, 5))

            with self.assertRaisesRegexp(ValueError, "cannot broadcast"):
                _test_func((3,), (4,), None)
            with self.assertRaisesRegexp(ValueError, "cannot broadcast"):
                _test_func((2, 1), (2, 4, 3), None)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号