def export_to_h5(checkpoint_dir, export_path, images, end_points, num_samples,
                 batch_size, sact):
  """Exports ponder cost maps and other useful info to an HDF5 file."""
  output_file = h5py.File(export_path, 'w')
  output_file.attrs['block_scopes'] = end_points['block_scopes']
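
  # Gather the per-block tensors to export: ponder cost, number of executed
  # units, halting distribution, and FLOPs for every residual block.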
  keys_to_tensors = {}
  for block_scope in end_points['block_scopes']:
    for k in ('{}/ponder_cost'.format(block_scope),
              '{}/num_units'.format(block_scope),
              '{}/halting_distribution'.format(block_scope),
              '{}/flops'.format(block_scope)):
      keys_to_tensors[k] = end_points[k]
  keys_to_tensors['images'] = images
  keys_to_tensors['flops'] = end_points['flops']
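
  # When spatially adaptive computation time (SACT) is enabled, also export
  # the per-position ponder cost and num_units maps built by sact_map.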
  if sact:
    keys_to_tensors['ponder_cost_map'] = sact_map(end_points, 'ponder_cost')
    keys_to_tensors['num_units_map'] = sact_map(end_points, 'num_units')
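
  # Create one HDF5 dataset per exported tensor, widening the leading (batch)
  # dimension from batch_size to the total number of samples.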
  keys_to_datasets = {}
  for key, tensor in keys_to_tensors.items():
    sh = tensor.get_shape().as_list()
    sh[0] = num_samples
    print(key, sh)
    keys_to_datasets[key] = output_file.create_dataset(
        key, sh, compression='lzf')
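
  # Restore the model variables from the most recent checkpoint.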
  variables_to_restore = slim.get_model_variables()
  checkpoint_path = tf.train.latest_checkpoint(checkpoint_dir)
  assert checkpoint_path is not None, (
      'No checkpoint found in {}'.format(checkpoint_dir))
  init_fn = slim.assign_from_checkpoint_fn(checkpoint_path,
                                           variables_to_restore)
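
  # The Supervisor only manages the session and input queue runners here;
  # summaries, checkpoint saving and other standard services are disabled.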
  sv = tf.train.Supervisor(
      graph=tf.get_default_graph(),
      logdir=None,
      summary_op=None,
      summary_writer=None,
      global_step=None,
      saver=None)
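
  # Each batch fills exactly batch_size rows of every dataset, so the total
  # sample count must be divisible by the batch size.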
  assert num_samples % batch_size == 0
  num_batches = num_samples // batch_size
  with sv.managed_session('', start_standard_services=False) as sess:
    init_fn(sess)
    sv.start_queue_runners(sess)
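
    # Evaluate the graph batch by batch, streaming each fetched value into
    # the corresponding slice of its HDF5 dataset.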
    for i in range(num_batches):
      tf.logging.info('Evaluating batch %d/%d', i + 1, num_batches)
      end_points_out = sess.run(keys_to_tensors)
      for key, dataset in keys_to_datasets.items():
        dataset[i * batch_size:(i + 1) * batch_size, ...] = end_points_out[key]

  output_file.close()
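

# A minimal sketch of reading the exported file back with h5py (the
# 'export.h5' path and the keys accessed below are illustrative):
#
#   import h5py
#   with h5py.File('export.h5', 'r') as f:
#     print(f.attrs['block_scopes'])
#     flops = f['flops'][...]
#     ponder_maps = f['ponder_cost_map'][...]  # present only when sact=True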