split_reps.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:Twitter100k 作者: huyt16 项目源码 文件源码
def main():
  """Split every representation and the labels into train/valid/test sets.

  Command-line arguments (read from sys.argv):
    1. data_pbtxt: passed to MakeDict to obtain the representations and
       their stats files.
    2. output_dir: directory that receives the split files and data.pbtxt;
       created if it does not exist.
    3. prefix: directory holding labels.npy and the splits/ subdirectory.
    4. r: integer split id, used in the index file names and dataset name.
    5. gpu_mem: stored in the dataset proto's gpu_memory field.
    6. main_mem: stored in the dataset proto's main_memory field.

  If <prefix>/splits/train_indices_<r>.npy exists, the saved train/valid/test
  indices are loaded; otherwise a new random split is created and saved
  there so later runs with the same r reuse it.
  """
  data_pbtxt = sys.argv[1]
  output_dir = sys.argv[2]
  prefix = sys.argv[3]
  r = int(sys.argv[4])
  gpu_mem = sys.argv[5]
  main_mem = sys.argv[6]
  if not os.path.isdir(output_dir):
    os.makedirs(output_dir)

  rep_dict, stats_files = MakeDict(data_pbtxt)

  # Build each index file path once; they are used for both load and save.
  splits_dir = os.path.join(prefix, 'splits')
  train_file = os.path.join(splits_dir, 'train_indices_%d.npy' % r)
  valid_file = os.path.join(splits_dir, 'valid_indices_%d.npy' % r)
  test_file = os.path.join(splits_dir, 'test_indices_%d.npy' % r)

  if os.path.exists(train_file):
    train = np.load(train_file)
    valid = np.load(valid_file)
    test = np.load(test_file)
  else:
    print('Creating new split.')
    # np.save does not create missing parent directories, so make sure the
    # splits directory exists before writing the new index files.
    if not os.path.isdir(splits_dir):
      os.makedirs(splits_dir)
    # Hard-coded sizes: 25000 examples split 10000 / 5000 / 10000.
    indices = np.arange(25000)
    np.random.shuffle(indices)
    train = indices[:10000]
    valid = indices[10000:15000]
    test = indices[15000:]
    np.save(train_file, train)
    np.save(valid_file, valid)
    np.save(test_file, test)

  print('Splitting data')
  dataset_pb = deepnet_pb2.Dataset()
  dataset_pb.name = 'flickr_split_%d' % r
  dataset_pb.gpu_memory = gpu_mem
  dataset_pb.main_memory = main_mem
  for rep in rep_dict:
    data = rep_dict[rep]
    stats_file = stats_files[rep]
    DumpDataSplit(data[train], output_dir, 'train_%s' % rep, dataset_pb, stats_file)
    DumpDataSplit(data[valid], output_dir, 'valid_%s' % rep, dataset_pb, stats_file)
    DumpDataSplit(data[test], output_dir, 'test_%s' % rep, dataset_pb, stats_file)

  print('Splitting labels')
  labels = np.load(os.path.join(prefix, 'labels.npy')).astype('float32')
  DumpLabelSplit(labels[train,], output_dir, 'train_labels', dataset_pb)
  DumpLabelSplit(labels[valid,], output_dir, 'valid_labels', dataset_pb)
  DumpLabelSplit(labels[test,], output_dir, 'test_labels', dataset_pb)

  # The split indices are persisted under <prefix>/splits, so they are not
  # copied into output_dir.

  with open(os.path.join(output_dir, 'data.pbtxt'), 'w') as f:
    text_format.PrintMessage(dataset_pb, f)

  print('Output written in directory %s' % output_dir)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号