test_basic.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_join_matrixV(self):
        """variable join axis"""
        v = numpy.array([[.1, .2, .3], [.4, .5, .6]], dtype=self.floatX)
        a = self.shared(v)
        b = as_tensor_variable(v)
        ax = lscalar()
        s = join(ax, a, b)

        f = inplace_func([ax], [s], mode=self.mode)
        topo = f.maker.fgraph.toposort()
        assert [True for node in topo
                if isinstance(node.op, type(self.join_op))]

        want = numpy.array([[.1, .2, .3], [.4, .5, .6],
                            [.1, .2, .3], [.4, .5, .6]])
        got = f(0)
        assert numpy.allclose(got, want)

        want = numpy.array([[.1, .2, .3, .1, .2, .3],
                            [.4, .5, .6, .4, .5, .6]])
        got = f(1)
        assert numpy.allclose(got, want)

        utt.verify_grad(lambda a, b: join(0, a, b), [v, 2 * v], mode=self.mode)
        utt.verify_grad(lambda a, b: join(1, a, b), [v, 2 * v], mode=self.mode)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号