def load_20newsgroup_vectorized(folder=SCIKIT_LEARN_DATA, one_hot=True, partitions_proportions=None,
                                shuffle=False, binary_problem=False, as_tensor=True, minus_value=-1.):
    """Load the vectorized 20 newsgroups train/test subsets as a `Datasets` object.

    The 'train' and 'test' subsets are fetched with
    `sk_dt.fetch_20newsgroups_vectorized`, their targets are optionally
    binarized and/or one-hot encoded, and each subset is wrapped in a
    `Dataset` whose `info` carries the training subset's target names.

    :param folder: passed as `data_home` to `fetch_20newsgroups_vectorized`.
    :param one_hot: if true, targets are passed through `to_one_hot_enc`.
    :param partitions_proportions: if truthy, the two datasets are re-split by
        `redivide_data` using these proportions.
    :param shuffle: accepted for interface compatibility but currently NOT
        used: `redivide_data` is always called with `shuffle=False`.
    :param binary_problem: if true, labels `< 10` become `minus_value` and
        labels `>= 10` become `1.` (applied before one-hot encoding).
    :param as_tensor: if true, `convert_to_tensor()` is called on every
        resulting dataset.
    :param minus_value: label assigned to the "negative" class when
        `binary_problem` is true.
    :return: the result of `Datasets.from_list` on the resulting datasets.
    """
    data_train = sk_dt.fetch_20newsgroups_vectorized(data_home=folder, subset='train')
    data_test = sk_dt.fetch_20newsgroups_vectorized(data_home=folder, subset='test')

    X_train = data_train.data
    X_test = data_test.data
    y_train = data_train.target
    y_test = data_test.target

    if binary_problem:
        # The label arrays are rewritten in place, so both masks must be
        # computed BEFORE any assignment; otherwise the second mask would be
        # evaluated on already-overwritten values (wrong if minus_value >= 10).
        # NOTE(review): if the label arrays have an integer dtype, a
        # non-integral minus_value is truncated on assignment -- confirm.
        for labels in (y_train, y_test):
            negative_mask = labels < 10
            positive_mask = labels >= 10
            labels[negative_mask] = minus_value
            labels[positive_mask] = 1.

    if one_hot:
        y_train = to_one_hot_enc(y_train)
        y_test = to_one_hot_enc(y_test)

    # Both datasets intentionally report the training subset's target names,
    # each through its own dict so the two `info` objects stay independent.
    d_train = Dataset(data=X_train, target=y_train,
                      info={'target names': data_train.target_names})
    d_test = Dataset(data=X_test, target=y_test,
                     info={'target names': data_train.target_names})

    res = [d_train, d_test]
    if partitions_proportions:
        # `shuffle` is deliberately not forwarded here (see docstring).
        res = redivide_data([d_train, d_test], partition_proportions=partitions_proportions,
                            shuffle=False)

    if as_tensor:
        # Called for its effect on each dataset; the return value is ignored.
        for dat in res:
            dat.convert_to_tensor()

    return Datasets.from_list(res)
评论列表
文章目录