test_slinalg.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_perform(self):
        """Compare the compiled ``kron`` graph with ``scipy.linalg.kron``.

        For every pairing of a left operand of rank 1-4 with a right operand
        of rank 1-4 (the vector/vector pairing is skipped), a function
        computing ``kron(x, y)`` is compiled and evaluated on values drawn
        from ``self.rng``; the result must be close to SciPy's.

        Raises ``SkipTest`` when scipy could not be imported.
        """
        if not imported_scipy:
            raise SkipTest('kron tests need the scipy package to be installed')

        left_shapes = [(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)]
        right_shapes = [(6,), (6, 7), (6, 7, 8), (6, 7, 8, 9)]

        for left_shape in left_shapes:
            left_rank = len(left_shape)
            x = tensor.tensor(dtype='floatX',
                              broadcastable=(False,) * left_rank)
            a = numpy.asarray(
                self.rng.rand(*left_shape)).astype(config.floatX)

            for right_shape in right_shapes:
                right_rank = len(right_shape)
                total_rank = left_rank + right_rank
                # The vector (x) vector case is not exercised by this test.
                if total_rank == 2:
                    continue

                y = tensor.tensor(dtype='floatX',
                                  broadcastable=(False,) * right_rank)
                kron_fn = function([x, y], kron(x, y))
                b = self.rng.rand(*right_shape).astype(config.floatX)
                actual = kron_fn(a, b)

                if total_rank == 3:
                    # One operand is a vector and the other a matrix: the
                    # reference value is computed with an extra leading axis
                    # on ``a`` and then flattened (reportedly required by
                    # newer scipy versions -- not verifiable from here).
                    expected = scipy.linalg.kron(
                        a[numpy.newaxis, :], b).flatten()
                else:
                    expected = scipy.linalg.kron(a, b)

                utt.assert_allclose(actual, expected)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号