def _write_lmdb(path, X, Y, indices, map_size):
    """Write the samples ``X[indices]`` / labels ``Y[indices]`` to a new LMDB.

    Each sample is converted with ``caffe.io.array_to_datum`` and stored under
    a zero-padded 10-digit key ('0000000000', '0000000001', ...) so that the
    database's lexicographic key order matches the order of ``indices``.

    Args:
        path: Directory of the LMDB environment to create/open.
        X: Sample array indexed by sample number (presumably shaped
            (N, C, H, W) as returned by ``get_data``; caffe's
            ``array_to_datum`` expects each ``X[i]`` to be 3-D -- confirm).
        Y: Label array indexed by sample number.
        indices: Sample numbers to write, in output order.
        map_size: Maximum size in bytes of the LMDB memory map.
    """
    env = lmdb.open(path, map_size=map_size)
    try:
        # A single write transaction; it commits when the ``with`` exits cleanly.
        with env.begin(write=True) as txn:
            for key_index, sample_index in enumerate(indices):
                datum = caffe.io.array_to_datum(X[sample_index], Y[sample_index])
                # lmdb keys must be bytes; ``.encode`` is a no-op type-wise on
                # Python 2 and required on Python 3.
                key = '{:0>10d}'.format(key_index).encode('ascii')
                txn.put(key, datum.SerializeToString())
    finally:
        # Always release the environment (file handle + memory map).
        env.close()


def make_train_val(data_file=None, train_db=None, val_db=None):
    """Split the mrbi train+valid data into train and validation Caffe LMDBs.

    Loads ``data_file`` with ``get_data``, draws a stratified train/validation
    split with 2000 validation samples, and writes each part to its own LMDB
    as serialized Caffe ``Datum`` records.

    Args:
        data_file: Path of the ``.amat`` file to load. Defaults to the
            original hard-coded mrbi train+valid file.
        train_db: Output LMDB directory for the training split. Defaults to
            the original hard-coded ``mrbi_train`` path.
        val_db: Output LMDB directory for the validation split. Defaults to
            the original hard-coded ``mrbi_val`` path.
    """
    mrbi_dir = '/home/lisha/school/Projects/hyperband_nnet/hyperband2/mrbi/'
    if data_file is None:
        data_file = (mrbi_dir + 'mnist_rotation_back_image_new/'
                     'mnist_all_background_images_rotation_normalized_train_valid.amat')
    if train_db is None:
        train_db = mrbi_dir + 'mrbi_train'
    if val_db is None:
        val_db = mrbi_dir + 'mrbi_val'

    print('Loading Matlab data.')
    X, Y = get_data(data_file)
    n_samples = Y.shape[0]
    # Total LMDB budget: twice the raw array size, to leave headroom for
    # Datum/protobuf and LMDB page overhead.
    map_size = X.nbytes * 2

    sss = StratifiedShuffleSplit(Y, 3, test_size=2000, random_state=0)
    # NOTE: three stratified splits are drawn but only the LAST one is used
    # (the original loop overwrote its variables on every iteration). Kept
    # as-is so exactly the same train/val partition is reproduced.
    ind_train, ind_val = list(sss)[-1]
    print('%d %d' % (len(ind_train), len(ind_val)))

    # Split the budget in proportion to each part's sample count. The original
    # hard-coded 5/6 and 1/6, which equals this only when n_samples == 12000
    # (presumably the mrbi train+valid size -- confirm); floor division keeps
    # map_size an int on both Python 2 and 3.
    _write_lmdb(train_db, X, Y, ind_train,
                map_size * len(ind_train) // n_samples)
    _write_lmdb(val_db, X, Y, ind_val,
                map_size * len(ind_val) // n_samples)
评论列表
文章目录