resample_data.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:CIKM2017 作者: heliarmk 项目源码 文件源码
def main(train_pkl="../data/CIKM2017_train/train_Imp_3x3.pkl",
         h5fname="../data/CIKM2017_train/train_Imp_3x3.h5"):
    """Resample the CIKM2017 training set and write it to an HDF5 file.

    Loads the pickled training records with ``joblib.load`` and splits them
    with ``sample`` into inputs, targets and class labels. Each part is then
    stored as an LZF-compressed, chunked dataset named ``train_set_x``,
    ``train_set_y`` and ``train_set_class``.

    Args:
        train_pkl: Path to the joblib pickle of training records. joblib
            unpickles the file, so it must come from a trusted source.
        h5fname: Destination HDF5 path. The file is opened with mode ``"w"``,
            so any existing file at this path is overwritten.
    """
    # Import up front so a missing h5py fails before the expensive load.
    import h5py

    trains = joblib.load(train_pkl)
    # NOTE(review): sample() is defined elsewhere; it is assumed to return
    # three array-likes exposing .shape/.dtype (e.g. NumPy arrays) — confirm.
    train_x, train_y, train_class = sample(trains)

    datasets = {
        "train_set_x": train_x,
        "train_set_y": train_y,
        "train_set_class": train_class,
    }
    with h5py.File(h5fname, "w") as f:
        for name, array in datasets.items():
            # shape and dtype are inferred by h5py from ``data``.
            f.create_dataset(name=name, data=array, compression="lzf", chunks=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号