datasets.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:RFHO 作者: lucfra 项目源码 文件源码
def load_20newsgroup_vectorized(folder=SCIKIT_LEARN_DATA, one_hot=True, partitions_proportions=None,
                                shuffle=False, binary_problem=False, as_tensor=True, minus_value=-1.):
    """Load the vectorized 20 newsgroups train/test subsets as a `Datasets` object.

    :param folder: passed as `data_home` to `sk_dt.fetch_20newsgroups_vectorized`.
    :param one_hot: if true, the targets of both subsets are passed through
                    `to_one_hot_enc` (after the optional binary relabeling).
    :param partitions_proportions: if truthy, the train and test datasets are handed
                                   to `redivide_data` with these proportions and the
                                   resulting list replaces the original train/test pair.
    :param shuffle: currently NOT used by this function; `redivide_data` is always
                    called with `shuffle=False` (see the note in the body).
    :param binary_problem: if true, labels `< 10` are replaced by `minus_value` and
                           labels `>= 10` are replaced by `1.`.
    :param as_tensor: if true, `convert_to_tensor()` is called on every resulting dataset.
    :param minus_value: label assigned to the first group (original labels `< 10`)
                        when `binary_problem` is set.
    :return: the result of `Datasets.from_list` applied to the list of datasets.
    """
    data_train = sk_dt.fetch_20newsgroups_vectorized(data_home=folder, subset='train')
    data_test = sk_dt.fetch_20newsgroups_vectorized(data_home=folder, subset='test')

    X_train = data_train.data
    X_test = data_test.data
    y_train = data_train.target
    y_test = data_test.target

    if binary_problem:
        def _binarize(labels):
            # Both masks are computed from the untouched labels *before* any write,
            # so the outcome does not depend on `minus_value` (writing in place and
            # re-testing `>= 10` afterwards would relabel entries whenever the
            # stored `minus_value` happened to be >= 10).
            first_group = labels < 10
            second_group = labels >= 10
            # Work on a copy so the fetched bunch is left unmodified.
            relabeled = labels.copy()
            # NOTE(review): the copy keeps the dtype of `labels` (presumably integer,
            # as returned by scikit-learn -- confirm), so a non-integral
            # `minus_value` is cast on assignment, exactly as before.
            relabeled[first_group] = minus_value
            relabeled[second_group] = 1.
            return relabeled

        y_train = _binarize(y_train)
        y_test = _binarize(y_test)

    if one_hot:
        y_train = to_one_hot_enc(y_train)
        y_test = to_one_hot_enc(y_test)

    # Both datasets carry the target names of the training subset.
    d_train = Dataset(data=X_train,
                      target=y_train, info={'target names': data_train.target_names})
    d_test = Dataset(data=X_test,
                     target=y_test, info={'target names': data_train.target_names})
    res = [d_train, d_test]
    if partitions_proportions:
        # NOTE(review): `shuffle` is deliberately not forwarded here; the original
        # code hard-codes shuffle=False. Confirm whether `redivide_data` supports
        # shuffling this kind of data before wiring the parameter through.
        res = redivide_data([d_train, d_test], partition_proportions=partitions_proportions, shuffle=False)

    if as_tensor:
        # Plain loop: the calls are made for their side effect only.
        for dat in res:
            dat.convert_to_tensor()

    return Datasets.from_list(res)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号