mat_to_lmdb.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:hyperband_benchmarks 作者: lishal 项目源码 文件源码
_SVHN_ROOT = '/home/lisha/school/Projects/hyperband_nnet/hyperband2/svhn'


def _last_stratified_split(labels, test_size, random_state):
    """Return ``(train_idx, test_idx)`` from the last of three stratified splits.

    Three shuffled splits are drawn and only the final one is kept, which
    reproduces the index selection this script has always used. Assumes the
    legacy ``StratifiedShuffleSplit(y, n_iter, test_size=..., random_state=...)``
    iterator API imported elsewhere in this file.
    """
    train_idx = test_idx = None
    for train_idx, test_idx in StratifiedShuffleSplit(
            labels, 3, test_size=test_size, random_state=random_state):
        pass  # only the last split is kept
    return train_idx, test_idx


def _flat_labels(raw_labels):
    """Return *raw_labels* as a flat int array with every label 10 mapped to 0.

    The 10 -> 0 remap presumably follows SVHN's convention of storing the
    digit 0 as class 10.
    """
    labels = np.array(raw_labels, dtype=int)  # real copy: modified in place below
    labels[labels == 10] = 0
    return labels.flatten()


def _samples_first(raw_images):
    """Return a view of *raw_images* with axis 3 (the per-sample axis) moved first.

    ``np.asarray`` avoids duplicating the (potentially multi-GB) array that
    ``loadmat`` already returned; ``np.rollaxis`` itself returns a view.
    """
    return np.rollaxis(np.asarray(raw_images), 3, 0)


def _write_lmdb(path, map_size, parts):
    """Write the selected (image, label) samples to a new LMDB at *path*.

    *parts* is a sequence of ``(images, labels, indices)`` triples. For each
    ``i`` in ``indices``, ``images[i]`` and ``labels[i]`` are packed into a Caffe
    ``Datum`` and stored in order under consecutive zero-padded keys
    ``0000000000``, ``0000000001``, ... Each image has its last axis moved to
    the front first (presumably HWC -> CHW as Caffe expects; confirm).
    """
    env = lmdb.open(path, map_size=map_size)
    try:
        with env.begin(write=True) as txn:
            key = 0
            for images, labels, indices in parts:
                for idx in indices:
                    datum = caffe.io.array_to_datum(
                        np.rollaxis(images[idx], 2, 0), int(labels[idx]))
                    # LMDB keys must be bytes on Python 3; same str on Python 2.
                    txn.put('{:0>10d}'.format(key).encode('ascii'),
                            datum.SerializeToString())
                    key += 1
    finally:
        env.close()  # release the environment even if a write fails


def make_train_val(train_mat=_SVHN_ROOT + '/svhn_data/train_32x32.mat',
                   extra_mat=_SVHN_ROOT + '/svhn_data/extra_32x32.mat',
                   val_db=_SVHN_ROOT + '/svhn_val',
                   train_db=_SVHN_ROOT + '/svhn_train'):
    """Build the SVHN train/validation LMDBs from the train and extra .mat files.

    A small stratified validation subset is carved out of both the train and
    the extra set; all remaining samples go to the training LMDB. Labels equal
    to 10 are rewritten as 0. Validation samples from the train set come
    first, followed by those from the extra set; the training LMDB uses the
    same ordering.

    Args:
        train_mat: path to the train ``.mat`` file (variables ``X`` and ``y``).
        extra_mat: path to the extra ``.mat`` file (variables ``X`` and ``y``).
        val_db: output path for the validation LMDB.
        train_db: output path for the training LMDB.

    Relies on ``scipy.io``, ``np``, ``StratifiedShuffleSplit``, ``lmdb`` and
    ``caffe`` being imported at module level.
    """
    print('Loading Matlab data.')
    train = scipy.io.loadmat(train_mat)
    extra = scipy.io.loadmat(extra_mat)

    # Validation fractions presumably chosen to yield ~4000 train-set and
    # ~2000 extra-set images for the standard SVHN sizes -- confirm.
    ind_train1, ind_val1 = _last_stratified_split(
        train.get('y'), 0.05460229056, random_state=0)
    ind_train2, ind_val2 = _last_stratified_split(
        extra.get('y'), 0.00376554936, random_state=1)
    print('val: ' + str(len(ind_val1) + len(ind_val2))
          + ' train: ' + str(len(ind_train1) + len(ind_train2)))

    Y1 = _flat_labels(train.get('y'))
    Y2 = _flat_labels(extra.get('y'))
    X1 = _samples_first(train.get('X'))
    X2 = _samples_first(extra.get('X'))

    # map_size must exceed each database's final size; these are loose bounds.
    _write_lmdb(val_db, X1.nbytes * 2,
                [(X1, Y1, ind_val1), (X2, Y2, ind_val2)])
    _write_lmdb(train_db, X2.nbytes * 4,
                [(X1, Y1, ind_train1), (X2, Y2, ind_train2)])
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号