test_basic.py 文件源码

python
阅读 32 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_tri(self):
        """Compare ``tri`` against ``numpy.tri`` for every dtype in ALL_DTYPES.

        For each dtype a set of (N, M, k) combinations is checked: M omitted,
        M != N with k == 0, and N == M, N < M, N > M with k == 1 and k == -1.

        The cases are executed directly instead of being yielded. A method
        containing ``yield`` only returns a generator object when called, so
        a runner that merely calls the test method would never execute any
        of the checks and the test would pass without testing anything.
        """
        def check(dtype, N, M_=None, k=0):
            """Compile ``tri`` on symbolic N, M, k and compare with numpy."""
            # Theano does not accept None as a tensor.
            # So we must use a real value.
            M = M_
            # Currently DebugMode does not support None as inputs even if this is
            # allowed.
            if M is None and theano.config.mode in ['DebugMode', 'DEBUG_MODE']:
                M = N
            N_symb = tensor.iscalar()
            M_symb = tensor.iscalar()
            k_symb = tensor.iscalar()
            f = function([N_symb, M_symb, k_symb],
                         tri(N_symb, M_symb, k_symb, dtype=dtype))
            result = f(N, M, k)
            # The reference is built from the caller's M_ (possibly None),
            # not from the value substituted for DebugMode above.
            expected = numpy.tri(N, M_, k, dtype=dtype)
            self.assertTrue(numpy.allclose(result, expected))
            self.assertTrue(result.dtype == numpy.dtype(dtype))

        # Positional arguments passed to check() after the dtype.
        cases = (
            (3,),
            # M != N, k = 0
            (3, 5),
            (5, 3),
            # N == M, k != 0
            (3, 3, 1),
            (3, 3, -1),
            # N < M, k != 0
            (3, 5, 1),
            (3, 5, -1),
            # N > M, k != 0
            (5, 3, 1),
            (5, 3, -1),
        )
        for dtype in ALL_DTYPES:
            for args in cases:
                check(dtype, *args)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号