test_opt.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_dot_allocs_0(self):
    """Check that ``dot`` with a ``zeros_like`` operand compiles without Dot.

    Every vector/matrix pairing of left and right operand is tried, with
    either the left or the right operand replaced by
    ``tensor.zeros_like`` of itself.  For each compiled function the test:

    * calls it with two sets of shape-compatible inputs,
    * asserts that no node of the compiled graph has an op that is an
      instance of ``tensor.Dot``,
    * asserts that shape-incompatible inputs still raise ``ValueError``
      or ``AssertionError``.

    The function is compiled with ``self.mode``.
    """
    sym_vec_left = tensor.vector('v1')
    sym_vec_right = tensor.vector('v2')
    sym_mat_left = tensor.matrix('m1')
    sym_mat_right = tensor.matrix('m2')

    float_type = theano.config.floatX
    vec_len2 = numpy.asarray([0, 1], dtype=float_type)
    mat_2x2 = numpy.asarray([[1, 2], [4, 5]], dtype=float_type)
    vec_len3 = numpy.asarray([0, 1, 2], dtype=float_type)
    mat_3x3 = numpy.asarray([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
                            dtype=float_type)

    # Each case is (symbolic input, size-2 value, size-3 value).
    left_cases = [(sym_vec_left, vec_len2, vec_len3),
                  (sym_mat_left, mat_2x2, mat_3x3)]
    right_cases = [(sym_vec_right, vec_len2, vec_len3),
                   (sym_mat_right, mat_2x2, mat_3x3)]
    shape_errors = (ValueError, AssertionError)

    for left_sym, left_small, left_big in left_cases:
        for right_sym, right_small, right_big in right_cases:
            # zero_side selects which operand becomes all zeros:
            # 0 -> the left one, 1 -> the right one.
            for zero_side in [0, 1]:
                if zero_side == 0:
                    lhs = tensor.zeros_like(left_sym)
                    rhs = right_sym
                else:
                    lhs = left_sym
                    rhs = tensor.zeros_like(right_sym)

                product = tensor.dot(lhs, rhs)
                compiled = theano.function([left_sym, right_sym], product,
                                           mode=self.mode)

                # Shape-compatible inputs must run without error.
                compiled(left_small, right_small)
                compiled(left_big, right_big)

                no_dot_flags = [
                    not isinstance(node.op, tensor.Dot)
                    for node in compiled.maker.fgraph.toposort()
                ]
                assert numpy.all(no_dot_flags)

                # test that we don't remove shape errors
                self.assertRaises(shape_errors, compiled,
                                  left_small, right_big)
                self.assertRaises(shape_errors, compiled,
                                  left_big, right_small)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号