make_lmdb.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:hyperband_benchmarks 作者: lishal 项目源码 文件源码
def _write_lmdb(path, images, labels, indices, map_size):
    """Write the selected samples to the LMDB at ``path`` as Caffe Datums.

    Samples are stored in the order given by ``indices`` under zero-padded,
    10-digit sequential keys ('0000000000', '0000000001', ...).

    Args:
        path: Directory of the LMDB environment to open for writing.
        images: Array of images; ``images[k]`` is passed to
            ``caffe.io.array_to_datum``.
        labels: Array of labels aligned with ``images``.
        indices: Positions into ``images``/``labels`` to write, in order.
        map_size: Maximum size in bytes the LMDB may grow to.
    """
    env = lmdb.open(path, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            for i, src in enumerate(indices):
                # Index the single sample directly; fancy-indexing the whole
                # array first would copy every selected image per iteration.
                im_dat = caffe.io.array_to_datum(images[src], int(labels[src]))
                # LMDB keys must be bytes on Python 3 (no-op on Python 2).
                key = '{:0>10d}'.format(i).encode('ascii')
                txn.put(key, im_dat.SerializeToString())
    finally:
        # Always release the environment, even if a write fails.
        env.close()


def get_data():
    """Split the CIFAR-10 training LMDB into stratified train/val LMDBs.

    Reads every Datum from the ``cifar10_train_lmdb`` database, performs a
    stratified 80/20 shuffle split on the labels (``random_state=0``), and
    writes the held-out 20% to ``cifar10_evenval_lmdb2`` and the remaining
    80% to ``cifar10_eventrain_lmdb2``.

    Returns:
        None. The results are the two LMDB databases written to disk.
    """
    x = []
    y = []
    datum = caffe.proto.caffe_pb2.Datum()
    lmdb_env = lmdb.open('/home/lisha/school/caffe/examples/cifar10/cifar10_train_lmdb//')
    try:
        with lmdb_env.begin() as lmdb_txn:
            for _, value in lmdb_txn.cursor():
                datum.ParseFromString(value)
                x.append(caffe.io.datum_to_array(datum))
                y.append(datum.label)
    finally:
        lmdb_env.close()

    x = np.array(x)
    y = np.array(y)

    map_size = int(1e12)
    # Only the last of the three generated splits is used. The count is left
    # at 3 on purpose: reducing it would select a different split and change
    # the contents of the generated databases.
    sss = StratifiedShuffleSplit(y, 3, test_size=0.2, random_state=0)
    for train_index, test_index in sss:
        ind_train = train_index
        ind_test = test_index

    # The number of records written follows the split sizes rather than
    # hard-coded counts, so datasets of any size are handled.
    _write_lmdb('/home/lisha/school/caffe/examples/cifar10/cifar10_evenval_lmdb2/',
                x, y, ind_test, map_size)
    _write_lmdb('/home/lisha/school/caffe/examples/cifar10/cifar10_eventrain_lmdb2/',
                x, y, ind_train, map_size)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号