test_basic.py 文件源码

python
阅读 35 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_join_matrix_ints(self):
        if "float32" in self.shared.__name__:
            raise SkipTest(
                "The shared variable constructor"
                " need to support other dtype then float32")
        # Test mixed dtype. There was a bug that caused crash in the past.
        av = numpy.array([[1, 2, 3], [4, 5, 6]], dtype='int8')
        bv = numpy.array([[7], [8]], dtype='int32')
        a = self.shared(av)
        b = as_tensor_variable(bv)
        s = join(1, a, b)
        want = numpy.array([[1, 2, 3, 7], [4, 5, 6, 8]], dtype='float32')
        out = self.eval_outputs_and_check_join([s])
        self.assertTrue((out == want).all())

        assert (numpy.asarray(grad(s.sum(), b).eval()) == 0).all()
        assert (numpy.asarray(grad(s.sum(), a).eval()) == 0).all()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号