def make_train_test(Y, ntest, seed = None):
    """Randomly split the stored entries of a sparse matrix into train and test parts.

    Parameters
    ----------
    Y : scipy.sparse.coo_matrix, csr_matrix or csc_matrix
        Matrix whose stored entries are split.
    ntest : real number, >= 0
        If ``ntest < 1`` it is the fraction of ``Y.nnz`` entries to put in the
        test set; otherwise it is the number of test entries. The resulting
        count is rounded to the nearest integer. A count larger than
        ``Y.nnz`` puts every entry into the test set.
    seed : optional
        If not None, passed to ``np.random.seed`` before permuting. Note that
        this reseeds NumPy's *global* random state.

    Returns
    -------
    (Ytrain, Ytest) : tuple of scipy.sparse.coo_matrix
        Both have ``Y.shape``; together their stored entries are exactly the
        stored entries of ``Y``, with no entry in both.

    Raises
    ------
    ValueError
        If ``Y`` is not one of the supported sparse types, or ``ntest`` is
        not a non-negative real number.
    """
    # Public class names (the scipy.sparse.coo/csr/csc submodules are
    # deprecated); isinstance also accepts subclasses of these types.
    if not isinstance(Y, (sp.sparse.coo_matrix, sp.sparse.csr_matrix, sp.sparse.csc_matrix)):
        # Use % formatting: concatenating a str with a type object raises TypeError.
        raise ValueError("Unsupported Y type: %s" % type(Y))
    # `not ntest >= 0` (rather than `ntest < 0`) also rejects NaN.
    if not isinstance(ntest, numbers.Real) or not ntest >= 0:
        raise ValueError("ntest has to be a non-negative number (number or ratio of test samples).")
    Y = Y.tocoo(copy = False)
    if ntest < 1:
        # Values below 1 are interpreted as a ratio of the stored entries.
        ntest = Y.nnz * ntest
    if seed is not None:
        np.random.seed(seed)
    ntest = int(round(ntest))
    # One random permutation of entry indices; the first ntest go to test.
    rperm = np.random.permutation(Y.nnz)
    test = rperm[:ntest]
    train = rperm[ntest:]
    Ytrain = sp.sparse.coo_matrix((Y.data[train], (Y.row[train], Y.col[train])), shape=Y.shape)
    Ytest = sp.sparse.coo_matrix((Y.data[test], (Y.row[test], Y.col[test])), shape=Y.shape)
    return Ytrain, Ytest
评论列表
文章目录