def _load_or_create_split(prefix, r):
  """Return the (train, valid, test) index arrays for split number `r`.

  The split is cached as three .npy files under `<prefix>/splits`. If the
  train index file for `r` exists, all three files are loaded. Otherwise a
  new random split is drawn, saved there, and returned.

  Args:
    prefix: Directory that contains (or will contain) the `splits` folder.
    r: Integer split number used in the index file names.

  Returns:
    Tuple (train, valid, test) of index arrays.
  """
  # Sizes of the index pool and of the train / valid partitions; the test
  # partition receives whatever remains.
  num_examples = 25000
  num_train = 10000
  num_valid = 5000

  splits_dir = os.path.join(prefix, 'splits')
  train_file = os.path.join(splits_dir, 'train_indices_%d.npy' % r)
  valid_file = os.path.join(splits_dir, 'valid_indices_%d.npy' % r)
  test_file = os.path.join(splits_dir, 'test_indices_%d.npy' % r)

  # Only the train file is probed; the valid/test files are expected to
  # accompany it and np.load raises if one of them is missing.
  if os.path.exists(train_file):
    return np.load(train_file), np.load(valid_file), np.load(test_file)

  print('Creating new split.')
  indices = np.arange(num_examples)
  np.random.shuffle(indices)
  train = indices[:num_train]
  valid = indices[num_train:num_train + num_valid]
  test = indices[num_train + num_valid:]

  # np.save does not create missing directories, so make sure the target
  # folder exists before writing the freshly drawn split.
  if not os.path.isdir(splits_dir):
    os.makedirs(splits_dir)
  np.save(train_file, train)
  np.save(valid_file, valid)
  np.save(test_file, test)
  return train, valid, test


def main():
  """Split every data representation and the labels into train/valid/test.

  Command-line arguments (sys.argv[1:7]):
    data_pbtxt: passed to MakeDict to obtain the per-representation data and
      stats files.
    output_dir: directory that receives the split files and `data.pbtxt`;
      created if it does not exist.
    prefix: directory holding `labels.npy` and the `splits` folder.
    r: integer split number.
    gpu_mem, main_mem: stored unchanged (as strings) in the dataset proto's
      `gpu_memory` and `main_memory` fields.

  Exits with a usage message if fewer than six arguments are supplied.
  """
  if len(sys.argv) < 7:
    sys.exit('Usage: %s <data_pbtxt> <output_dir> <prefix> <split_number> '
             '<gpu_mem> <main_mem>' % sys.argv[0])

  data_pbtxt = sys.argv[1]
  output_dir = sys.argv[2]
  prefix = sys.argv[3]
  r = int(sys.argv[4])
  gpu_mem = sys.argv[5]
  main_mem = sys.argv[6]

  if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

  rep_dict, stats_files = MakeDict(data_pbtxt)
  train, valid, test = _load_or_create_split(prefix, r)

  print('Splitting data')
  dataset_pb = deepnet_pb2.Dataset()
  dataset_pb.name = 'flickr_split_%d' % r
  dataset_pb.gpu_memory = gpu_mem
  dataset_pb.main_memory = main_mem
  for rep in rep_dict:
    data = rep_dict[rep]
    stats_file = stats_files[rep]
    # NOTE(review): `data` is indexed with the index arrays, so it is
    # presumably a numpy array with one row per example -- confirm in MakeDict.
    DumpDataSplit(data[train], output_dir, 'train_%s' % rep, dataset_pb, stats_file)
    DumpDataSplit(data[valid], output_dir, 'valid_%s' % rep, dataset_pb, stats_file)
    DumpDataSplit(data[test], output_dir, 'test_%s' % rep, dataset_pb, stats_file)

  print('Splitting labels')
  labels = np.load(os.path.join(prefix, 'labels.npy')).astype('float32')
  DumpLabelSplit(labels[train,], output_dir, 'train_labels', dataset_pb)
  DumpLabelSplit(labels[valid,], output_dir, 'valid_labels', dataset_pb)
  DumpLabelSplit(labels[test,], output_dir, 'test_labels', dataset_pb)

  # Write the accumulated dataset description next to the split files.
  with open(os.path.join(output_dir, 'data.pbtxt'), 'w') as f:
    text_format.PrintMessage(dataset_pb, f)

  print('Output written in directory %s' % output_dir)
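# Web-page residue, not part of the script: "评论列表" (comment list)
# Web-page residue, not part of the script: "文章目录" (table of contents)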