def test_dump():
    """Round-trip the reference data through dump/load_svmlight_file.

    The data loaded from ``datafile`` is dumped in three forms (sparse,
    dense, and a row-sliced sparse copy), for both index bases and for
    float32, float64 and int32 values. Each dump is checked for:

    * a first header line containing the scikit-learn version,
    * a second header line naming the index base ("zero-based" or
      "one-based"),
    * a reload that has the requested dtype, sorted column indices,
      values matching the dense reference (4 decimals for float32,
      15 otherwise) and identical labels.
    """
    X_sparse, y = load_svmlight_file(datafile)
    X_dense = X_sparse.toarray()

    # Indexing a csr_matrix with an index array can leave its .indices
    # unsorted; include such a matrix to check that the indices come
    # back sorted after the round trip.
    X_sliced = X_sparse[np.arange(X_sparse.shape[0])]

    def read_text_line(stream):
        # Read one line and decode it as UTF-8 where possible.
        raw = stream.readline()
        try:
            return str(raw, "utf-8")
        except TypeError:  # fails in Python 2.x; keep the raw line
            return raw

    for data in (X_sparse, X_dense, X_sliced):
        for zero_based in (True, False):
            for dtype in [np.float32, np.float64, np.int32]:
                buf = BytesIO()
                # The version info is only written when a comment is
                # passed: LibSVM doesn't understand comments, so they
                # are no longer emitted by default.
                dump_svmlight_file(data.astype(dtype), y, buf,
                                   comment="test", zero_based=zero_based)
                buf.seek(0)

                version_line = read_text_line(buf)
                assert_in("scikit-learn %s" % sklearn.__version__,
                          version_line)

                base_line = read_text_line(buf)
                base_name = "zero" if zero_based else "one"
                assert_in(base_name + "-based", base_line)

                X_loaded, y_loaded = load_svmlight_file(
                    buf, dtype=dtype, zero_based=zero_based)
                assert_equal(X_loaded.dtype, dtype)
                assert_array_equal(X_loaded.sorted_indices().indices,
                                   X_loaded.indices)

                # allow a rounding error at the last decimal place
                decimals = 4 if dtype == np.float32 else 15
                assert_array_almost_equal(
                    X_dense.astype(dtype), X_loaded.toarray(), decimals)
                assert_array_equal(y, y_loaded)
评论列表
文章目录