test_basic.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_infer_shape(self):
        for shp1, shp2 in [
            ((5, 4), (7, 4)),
            ((1, 4), (7, 4)),
            ((5, 1), (7, 4)),
            ((5, 4), (1, 4)),
            ((5, 4), (7, 1)),

            ((5, 4), (4,)),
            ((1, 4), (4,)),
            ((5, 1), (4,)),
            ((5, 4), (1,)),

            ((4,), (5, 4)),
            ((1,), (5, 4)),
            ((4,), (1, 4)),
            ((4,), (3, 1)),

            ((4,), (4,)),
            ((1,), (4,)),
            ((4,), (1,)),
            ((1,), (1,)),
        ]:
            a = tensor.tensor(dtype='int32',
                              broadcastable=[n == 1 for n in shp1])
            c = tensor.tensor(dtype='float32',
                              broadcastable=[n == 1 for n in shp2])
            A = numpy.asarray(numpy.random.rand(*shp1) * shp2[0], dtype='int32')
            C = numpy.asarray(numpy.random.rand(*shp2) * shp2[0], dtype='float32')
            self._compile_and_check([a, c],  # theano.function inputs
                                    [self.op(a, c)],  # theano.function outputs
                                    # Always use not square matrix!
                                    # inputs data
                                    [A, C],
                                    # Op that should be removed from the graph.
                                    self.op_class)

# Disabled as it isn't implemented.
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号