summary_utils.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:sact 作者: mfigurnov 项目源码 文件源码
def export_to_h5(checkpoint_dir, export_path, images, end_points, num_samples,
                 batch_size, sact):
  """Exports ponder cost maps and other useful info to an HDF5 file.

  Restores the model variables from the latest checkpoint in
  `checkpoint_dir`, runs `num_samples // batch_size` evaluation steps and
  writes every collected tensor into an LZF-compressed dataset named after
  its key.

  Args:
    checkpoint_dir: Directory searched with `tf.train.latest_checkpoint`.
    export_path: Path of the HDF5 file to create (opened in 'w' mode, so an
      existing file is overwritten).
    images: Images tensor; exported as the 'images' dataset.
    end_points: Dict of model end points. Must contain 'block_scopes',
      'flops' and, for every block scope, the keys '<scope>/ponder_cost',
      '<scope>/num_units', '<scope>/halting_distribution' and
      '<scope>/flops'.
    num_samples: Total number of rows written to each dataset; must be a
      multiple of `batch_size`.
    batch_size: Number of rows written per `sess.run` call. Presumably equals
      the leading dimension of every exported tensor -- TODO confirm.
    sact: If true, additionally exports 'ponder_cost_map' and
      'num_units_map' built with `sact_map`.

  Raises:
    ValueError: If `batch_size` is not positive, `num_samples` is not a
      multiple of `batch_size`, or no checkpoint exists in `checkpoint_dir`.
  """
  # Validate before touching the output file, using real exceptions instead
  # of asserts (which are stripped under `python -O`), so a bad call does not
  # leave a truncated HDF5 file behind.
  if batch_size <= 0:
    raise ValueError('batch_size must be positive, got %r' % (batch_size,))
  if num_samples % batch_size != 0:
    raise ValueError(
        'num_samples (%d) must be a multiple of batch_size (%d)'
        % (num_samples, batch_size))
  num_batches = num_samples // batch_size

  # Collect every tensor to export, keyed by its HDF5 dataset name.
  keys_to_tensors = {}
  for block_scope in end_points['block_scopes']:
    for suffix in ('ponder_cost', 'num_units', 'halting_distribution',
                   'flops'):
      key = '{}/{}'.format(block_scope, suffix)
      keys_to_tensors[key] = end_points[key]
  keys_to_tensors['images'] = images
  keys_to_tensors['flops'] = end_points['flops']

  if sact:
    keys_to_tensors['ponder_cost_map'] = sact_map(end_points, 'ponder_cost')
    keys_to_tensors['num_units_map'] = sact_map(end_points, 'num_units')

  # Build the restore op. All graph construction must happen before the
  # Supervisor below is created.
  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  if checkpoint_path is None:
    raise ValueError('No checkpoint found in %s' % (checkpoint_dir,))
  init_fn = slim.assign_from_checkpoint_fn(checkpoint_path,
                                           slim.get_model_variables())

  # A bare Supervisor: no logdir, summaries, saver or global step. It is only
  # used to manage the session and the queue runners.
  sv = tf.train.Supervisor(
      graph=tf.get_default_graph(),
      logdir=None,
      summary_op=None,
      summary_writer=None,
      global_step=None,
      saver=None)

  # The context manager guarantees the file is flushed and closed, even if
  # evaluation fails part-way through.
  with h5py.File(export_path, 'w') as output_file:
    output_file.attrs['block_scopes'] = end_points['block_scopes']

    keys_to_datasets = {}
    for key, tensor in keys_to_tensors.items():
      shape = tensor.get_shape().as_list()
      # Replace the per-batch leading dimension with the total row count so
      # that one dataset holds the results of every batch.
      shape[0] = num_samples
      tf.logging.info('Creating dataset %s with shape %s', key, shape)
      # NOTE(review): no dtype is passed, so h5py uses its default float
      # dtype regardless of the tensor's dtype -- confirm this is intended.
      keys_to_datasets[key] = output_file.create_dataset(
          key, shape, compression='lzf')

    with sv.managed_session('', start_standard_services=False) as sess:
      init_fn(sess)
      sv.start_queue_runners(sess)

      for i in range(num_batches):
        tf.logging.info('Evaluating batch %d/%d', i + 1, num_batches)
        end_points_out = sess.run(keys_to_tensors)
        start = i * batch_size
        for key, dataset in keys_to_datasets.items():
          dataset[start:start + batch_size, ...] = end_points_out[key]
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号