mrbi_to_lmdb.py source code

python

Project: hyperband_benchmarks  Author: lishal
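import caffe
import lmdb
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit


def make_train_val():
    print('Loading Matlab data.')
    f = '/home/lisha/school/Projects/hyperband_nnet/hyperband2/mrbi/mnist_rotation_back_image_new/mnist_all_background_images_rotation_normalized_train_valid.amat'

    # get_data is defined elsewhere in the project; it returns the images X and labels Y.
    X, Y = get_data(f)
    N = Y.shape[0]
    # Upper bound on the total LMDB size in bytes; split 5:1 between train and val below.
    map_size = X.nbytes * 2
    # To shuffle the data, permute the indices first, e.g. np.random.permutation(N).

    # Stratified split with 2000 validation examples. Three splits are drawn,
    # but only the last one is kept.
    sss = StratifiedShuffleSplit(n_splits=3, test_size=2000, random_state=0)
    for train_index, test_index in sss.split(np.zeros(N), Y):
        ind_train1 = train_index
        ind_val1 = test_index
    print(len(ind_train1), len(ind_val1))

    # map_size must be an integer, hence the floor division.
    env = lmdb.open('/home/lisha/school/Projects/hyperband_nnet/hyperband2/mrbi/mrbi_train', map_size=map_size * 5 // 6)
    with env.begin(write=True) as txn:
        # txn is a Transaction object
        for i in range(len(ind_train1)):
            im_dat = caffe.io.array_to_datum(X[ind_train1[i]], int(Y[ind_train1[i]]))
            # LMDB keys must be bytes; zero-padding keeps records in index order.
            txn.put('{:0>10d}'.format(i).encode('ascii'), im_dat.SerializeToString())
    env.close()

    env = lmdb.open('/home/lisha/school/Projects/hyperband_nnet/hyperband2/mrbi/mrbi_val', map_size=map_size // 6)
    with env.begin(write=True) as txn:
        # txn is a Transaction object
        for i in range(len(ind_val1)):
            im_dat = caffe.io.array_to_datum(X[ind_val1[i]], int(Y[ind_val1[i]]))
            txn.put('{:0>10d}'.format(i).encode('ascii'), im_dat.SerializeToString())
    env.close()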
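
The function above builds its record keys with '{:0>10d}'.format(i). LMDB iterates keys in lexicographic byte order by default, so the zero-padding is what makes a cursor return the records in the order they were written. The following is a minimal sketch of that effect using only the standard library; the helper name record_key and its width parameter are hypothetical and not part of the original project.

```python
def record_key(index, width=10):
    """Return a zero-padded ASCII key, mirroring '{:0>10d}'.format(index).

    record_key and width are illustrative names, not from the project.
    """
    return '{:0>{width}d}'.format(index, width=width).encode('ascii')


indices = [0, 2, 10, 100, 9]

# Byte-wise sorting stands in for the order in which LMDB would iterate the keys.
padded = sorted(record_key(i) for i in indices)
unpadded = sorted(str(i).encode('ascii') for i in indices)

print(padded)    # numeric order is preserved
print(unpadded)  # b'10' and b'100' sort before b'2'
```

With a width of 10 digits the keys stay correctly ordered for up to 10**10 records, far more than this dataset needs.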