test_basic.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test2_valid_neg(self):
        """Check max_and_argmax with negative axis arguments.

        The input is a tensor variable built from ``rand(2, 3)``. For axis
        -1 and then axis -2, the evaluated maxima and indices are compared
        against ``numpy.max`` / ``numpy.argmax`` of ``n.value`` along the
        same axis, and the index dtype must be 'int64'. Afterwards the
        ``.shape`` of the max output is evaluated on its own for each axis.
        """
        n = as_tensor_variable(rand(2, 3))
        # (axis, expected length of the reduced result), in checking order.
        cases = ((-1, 2), (-2, 3))

        # Pass 1: evaluate both outputs and compare them with numpy.
        for axis, length in cases:
            maxima, indices = eval_outputs(max_and_argmax(n, axis))
            assert indices.dtype == 'int64'
            self.assertTrue(maxima.shape == (length,))
            self.assertTrue(indices.shape == (length,))
            self.assertTrue(numpy.all(maxima == numpy.max(n.value, axis)))
            self.assertTrue(
                numpy.all(indices == numpy.argmax(n.value, axis)))

        # Pass 2: evaluate only the shape of the first (max) output.
        for axis, length in cases:
            max_out = max_and_argmax(n, axis)[0]
            shape_value = eval_outputs(max_out.shape)
            # Compared against a plain int, exactly as the original did.
            assert shape_value == length
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号