test_svmlight_format.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:Parallel-SGD 作者:angadgill 项目源码 文件源码
def test_dump():
    """Round-trip data through ``dump_svmlight_file`` / ``load_svmlight_file``.

    For the sparse matrix loaded from ``datafile``, its dense equivalent, and
    a row-sliced copy of the sparse matrix, dump every combination of
    ``zero_based`` (True/False) and dtype (float32, float64, int32) into an
    in-memory buffer, then check that:

    * the first header line mentions the installed scikit-learn version,
    * the second header line says "zero-based" or "one-based" to match
      ``zero_based``,
    * reloading yields the requested dtype, sorted ``indices``, values equal
      to the dense input (to 4 decimals for float32, 15 otherwise), and
      identical targets.
    """
    Xs, y = load_svmlight_file(datafile)
    Xd = Xs.toarray()

    # slicing a csr_matrix can unsort its .indices, so test that we sort
    # those correctly
    Xsliced = Xs[np.arange(Xs.shape[0])]

    for X in (Xs, Xd, Xsliced):
        for zero_based in (True, False):
            for dtype in [np.float32, np.float64, np.int32]:
                f = BytesIO()
                # we need to pass a comment to get the version info in;
                # LibSVM doesn't grok comments so they're not put in by
                # default anymore.
                dump_svmlight_file(X.astype(dtype), y, f, comment="test",
                                   zero_based=zero_based)
                f.seek(0)

                # BytesIO.readline returns bytes; bytes.decode produces text
                # on both Python 2 and 3, so no TypeError-based version
                # sniffing is needed.
                comment = f.readline().decode("utf-8")
                assert_in("scikit-learn %s" % sklearn.__version__, comment)

                comment = f.readline().decode("utf-8")
                # A bool indexes the list: False -> "one", True -> "zero".
                assert_in(["one", "zero"][zero_based] + "-based", comment)

                X2, y2 = load_svmlight_file(f, dtype=dtype,
                                            zero_based=zero_based)
                assert_equal(X2.dtype, dtype)
                assert_array_equal(X2.sorted_indices().indices, X2.indices)
                # float32 carries only ~7 significant digits, so allow a
                # rounding error at a much earlier decimal place than for
                # float64/int32.
                decimal = 4 if dtype == np.float32 else 15
                assert_array_almost_equal(Xd.astype(dtype), X2.toarray(),
                                          decimal)
                assert_array_equal(y, y2)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号