uci_yeast.py source code

python

Project: gcForest    Author: kingfengji
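import os.path as osp
import re

import numpy as np
from sklearn.model_selection import train_test_split

# gcForest's own helper returning the root directory of the local datasets folder
from .ds_base import get_dataset_base


def load_data():
    """Load UCI Yeast and return a stratified 70/30 ((X_train, y_train), (X_test, y_test)) split."""
    # yeast.label maps integer class ids to class names, one "<id> <name>" pair per line
    id2label = {}
    label2id = {}
    label_path = osp.abspath(osp.join(get_dataset_base(), "uci_yeast", "yeast.label"))
    with open(label_path) as f:
        for row in f:
            cols = row.strip().split(" ")
            id2label[int(cols[0])] = cols[1]
            label2id[cols[1]] = int(cols[0])

    # yeast.data: sequence name, 8 numeric features, class name, separated by runs of spaces
    data_path = osp.abspath(osp.join(get_dataset_base(), "uci_yeast", "yeast.data"))
    with open(data_path) as f:
        rows = [row for row in f if row.strip()]  # skip blank lines such as a trailing newline
    n_datas = len(rows)
    X = np.zeros((n_datas, 8), dtype=np.float32)
    y = np.zeros(n_datas, dtype=np.int32)
    for i, row in enumerate(rows):
        cols = re.split(" +", row.strip())
        X[i, :] = list(map(float, cols[1:1 + 8]))  # cols[0] is the sequence name, not a feature
        y[i] = label2id[cols[-1]]
    # stratify=y keeps each class's proportion the same in the train and test splits
    train_idx, test_idx = train_test_split(
        np.arange(n_datas), random_state=0, train_size=0.7, stratify=y)
    return (X[train_idx], y[train_idx]), (X[test_idx], y[test_idx])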
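To see what the parsing and the stratified split above do without the real dataset files, here is a minimal stdlib-only sketch. Everything in it is hypothetical: the sample rows, the three-class subset (CYT, NUC, MIT), and the helpers `make_rows`, `parse_labels`, `parse_rows` and `stratified_split`. The file layouts are inferred from how `load_data()` splits each line, and the per-class shuffle is a hand-rolled stand-in for scikit-learn's `stratify=y`, so the chosen indices will not match `train_test_split`.

```python
import random
import re
from collections import defaultdict

# Hypothetical label file contents: "<integer id> <class name>" per line.
LABEL_TEXT = "0 CYT\n1 NUC\n2 MIT"


def make_rows(n_per_class):
    """Build hypothetical yeast.data-style rows: name, 8 features, class name."""
    rows = []
    for cls in ("CYT", "NUC", "MIT"):
        for k in range(n_per_class):
            feats = "  ".join(f"{(k * 0.07 + j * 0.1) % 1:.2f}" for j in range(8))
            rows.append(f"SEQ_{cls}_{k}  {feats}  {cls}")
    return rows


def parse_labels(text):
    """Map class name -> integer id, mirroring the label2id loop in load_data()."""
    label2id = {}
    for line in text.splitlines():
        idx, name = line.strip().split(" ")
        label2id[name] = int(idx)
    return label2id


def parse_rows(rows, label2id):
    """Split on runs of spaces; skip the name, keep 8 features, map the class."""
    X, y = [], []
    for row in rows:
        cols = re.split(" +", row.strip())
        X.append([float(v) for v in cols[1:9]])
        y.append(label2id[cols[-1]])
    return X, y


def stratified_split(y, train_size=0.7, seed=0):
    """Shuffle each class separately and send train_size of it to the train split."""
    by_class = defaultdict(list)
    for i, label in enumerate(y):
        by_class[label].append(i)
    rng = random.Random(seed)
    train_idx, test_idx = [], []
    for _, idxs in sorted(by_class.items()):
        rng.shuffle(idxs)
        n_train = round(len(idxs) * train_size)
        train_idx += idxs[:n_train]
        test_idx += idxs[n_train:]
    return sorted(train_idx), sorted(test_idx)


label2id = parse_labels(LABEL_TEXT)
X, y = parse_rows(make_rows(10), label2id)
train_idx, test_idx = stratified_split(y)
for cls, cid in label2id.items():
    n_tr = sum(1 for i in train_idx if y[i] == cid)
    n_te = sum(1 for i in test_idx if y[i] == cid)
    print(cls, n_tr, n_te)  # prints "CYT 7 3", then "NUC 7 3", then "MIT 7 3"
```

With 10 rows per class, every class contributes 7 rows to training and 3 to testing; this per-class proportion (up to rounding) is the property that `stratify=y` is meant to preserve in the real loader.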