datasets.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:RFHO 作者: lucfra 项目源码 文件源码
def generate_multiclass_dataset(n_samples=100, n_features=10,
                                n_informative=5, n_redundant=3, n_repeated=2,
                                n_classes=2, n_clusters_per_class=2,
                                weights=None, flip_y=0.01, class_sep=1.0,
                                hypercube=True, shift=0.0, scale=1.0,
                                shuffle=True, random_state=None, hot_encoded=True, partitions_proportions=None,
                                negative_labels=-1.):
    """Generate a synthetic classification dataset via ``sk_dt.make_classification``.

    The generated samples and labels are packed into a ``Dataset`` with float32
    ``data`` and ``target``; the generation parameters are recorded in its
    ``info`` dict. If ``partitions_proportions`` is given, the dataset is split
    with ``redivide_data`` and wrapped with ``Datasets.from_list``.

    Args:
        n_samples, n_features, n_informative, n_redundant, n_repeated,
        n_classes, n_clusters_per_class, weights, flip_y, class_sep,
        hypercube, shift, scale: forwarded unchanged to
            ``make_classification``.
        shuffle: forwarded to ``redivide_data`` when partitioning. Note that
            ``make_classification`` itself is always called with
            ``shuffle=True``, regardless of this argument.
        random_state: forwarded to ``make_classification``. It is also passed
            to ``np.random.seed`` (reseeding numpy's *global* RNG) before the
            optional partitioning step. The reseed happens even when no
            partitioning is requested.
        hot_encoded: if True, labels are converted with ``to_one_hot_enc``.
            Otherwise they are kept as scalars, with label 0 replaced by
            ``negative_labels``.
        partitions_proportions: if truthy, passed to ``redivide_data`` as
            ``partition_proportions``.
        negative_labels: value that replaces class 0 when ``hot_encoded`` is
            False (only class 0 is remapped; other class ids are kept).

    Returns:
        A ``Dataset``, or the result of ``Datasets.from_list`` applied to the
        partitions returned by ``redivide_data``.
    """
    X, y = sk_dt.make_classification(n_samples=n_samples, n_features=n_features,
                                     n_informative=n_informative, n_redundant=n_redundant, n_repeated=n_repeated,
                                     n_classes=n_classes, n_clusters_per_class=n_clusters_per_class,
                                     weights=weights, flip_y=flip_y, class_sep=class_sep,
                                     hypercube=hypercube, shift=shift, scale=scale,
                                     shuffle=True, random_state=random_state)
    if hot_encoded:
        y = to_one_hot_enc(y)
    else:
        # make_classification returns integer labels. Assigning in place into
        # that int array would truncate a non-integer negative_labels
        # (e.g. -0.5 -> 0), so build a float array instead.
        y = np.where(y == 0, negative_labels, y).astype(np.float32)
    res = Dataset(data=np.array(X, dtype=np.float32), target=np.array(y, dtype=np.float32),
                  info={'n_informative': n_informative, 'n_redundant': n_redundant,
                        'n_repeated': n_repeated,
                        'n_classes': n_classes, 'n_clusters_per_class': n_clusters_per_class,
                        'weights': weights, 'flip_y': flip_y, 'class_sep': class_sep,
                        'hypercube': hypercube, 'shift': shift, 'scale': scale,
                        # make_classification is always called with shuffle=True above.
                        'shuffle': True, 'random_state': random_state})
    # Seed numpy's global RNG so that any randomness in the partitioning step
    # below is reproducible for a fixed random_state.
    np.random.seed(random_state)
    if partitions_proportions:
        res = redivide_data([res], shuffle=shuffle, partition_proportions=partitions_proportions)
        res = Datasets.from_list(res)
    return res
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号