test_opt.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_broadcast2(self):
        """Exercise ``switch(constant, vector, matrix)`` for both constants.

        An int32 vector and an int64 matrix are passed as the two branches
        of ``theano.tensor.switch`` with the condition fixed to 1 and then
        to 0.  For each compiled function the test inspects the optimized
        graph for leftover ``Elemwise`` nodes and compares the computed
        output against the input of the selected branch.
        """
        # This case is not optimized for now.
        # NOTE(review): the assertions below nevertheless require that no
        # (non-Cast) Elemwise node remains -- confirm whether the comment
        # above is still accurate.
        tensor = theano.tensor
        vec = tensor.vector('x', dtype='int32')
        mat = tensor.matrix('y', dtype='int64')
        vec_val = numpy.array([4, 5, 6], dtype='int32')
        mat_val = numpy.array([[7, 8, 9], [10, 11, 12]], dtype='int64')

        # Condition fixed to 1: the vector branch is selected.  A Cast
        # Elemwise is tolerated in the graph; any other Elemwise is not.
        fn_true = theano.function([vec, mat], tensor.switch(1, vec, mat),
                                  mode=self.mode)
        leftover_ops = []
        for node in fn_true.maker.fgraph.toposort():
            if not isinstance(node.op, tensor.Elemwise):
                continue
            if isinstance(node.op.scalar_op, theano.scalar.basic.Cast):
                continue
            leftover_ops.append(node.op)
        assert len(leftover_ops) == 0
        assert numpy.all(fn_true(vec_val, mat_val) == vec_val)

        # Condition fixed to 0: the matrix branch is selected.  Here no
        # Elemwise node of any kind may remain in the graph.
        fn_false = theano.function([vec, mat], tensor.switch(0, vec, mat),
                                   mode=self.mode)
        elemwise_ops = []
        for node in fn_false.maker.fgraph.toposort():
            if isinstance(node.op, tensor.Elemwise):
                elemwise_ops.append(node.op)
        assert len(elemwise_ops) == 0
        assert numpy.all(fn_false(vec_val, mat_val) == mat_val)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号