test_basic.py 文件源码

python
阅读 42 收藏 0 点赞 0 评论 0

项目:Theano-Deep-learning 作者: GeekLiB 项目源码 文件源码
def test_remove0(self):
        """Check that ``Remove0`` drops explicitly stored zeros.

        For CSC and CSR inputs, and for every combination of
        "contains explicit zeros" / "has unsorted indices", this compiles
        ``Remove0()(x)`` and verifies that:

        * outside FAST_COMPILE mode, every ``Remove0`` node in the optimized
          graph is the inplace variant and at least one such node exists;
        * the output holds no stored zeros and stores exactly as many entries
          as an independent copy of the input after scipy's
          ``eliminate_zeros``;
        * index sortedness of output and reference follows the
          ``unsorted_indices`` request.
        """
        configs = [
            # sparse format name, matching scipy.sparse class (unused below)
            ('csc', scipy.sparse.csc_matrix),
            ('csr', scipy.sparse.csr_matrix), ]

        # (explicit_zero, unsorted_indices) combinations to exercise.
        flag_combinations = [(True, True), (True, False),
                             (False, True), (False, False)]

        # ``sp_format`` rather than ``format`` to avoid shadowing the builtin.
        for sp_format, _matrix_class in configs:
            for zero, unsor in flag_combinations:
                (x,), (mat,) = sparse_random_inputs(sp_format, (6, 8),
                                                    out_dtype=config.floatX,
                                                    explicit_zero=zero,
                                                    unsorted_indices=unsor)
                # Sanity-check that the generated input matches the request.
                assert 0 in mat.data or not zero
                assert not mat.has_sorted_indices or not unsor

                # The In wrapper is required: theano's rule is not to
                # optimize (i.e. operate inplace on) inputs unless the input
                # is explicitly declared borrowable and mutable.
                f = theano.function([theano.In(x, borrow=True, mutable=True)],
                                    Remove0()(x))

                # In modes with optimization, local_inplace_remove0 must have
                # replaced every Remove0 with its inplace variant.
                if theano.config.mode != 'FAST_COMPILE':
                    # Apply nodes of the optimized graph.
                    nodes = f.maker.fgraph.toposort()
                    remove0_nodes = [node for node in nodes
                                     if isinstance(node.op, Remove0)]
                    assert all(node.op.inplace for node in remove0_nodes), (
                        'Inplace optimization should have been applied')
                    assert remove0_nodes, (
                        'No inplace Remove0 node found in the optimized graph')

                # Build the reference from an independent copy *before*
                # calling f. Because the input is declared mutable, the
                # inplace Remove0 may rewrite ``mat`` itself; a reference
                # aliased to ``mat`` could then be compared with itself and
                # the size check below would pass vacuously.
                target = mat.copy()
                target.eliminate_zeros()

                result = f(mat)

                msg = 'Matrices sizes differ. Have zeros been removed ?'
                # ``.size`` on a scipy sparse matrix is the number of stored
                # entries, explicit zeros included.
                assert result.size == target.size, msg
                assert 0 not in result.data, msg
                if unsor:
                    assert not result.has_sorted_indices
                    assert not target.has_sorted_indices
                else:
                    assert result.has_sorted_indices
                    assert target.has_sorted_indices
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号