def inference(reader, train_dir, data_pattern, out_file_location, batch_size):
    """Run the latest checkpointed model over the input data and write a CSV.

    The newest checkpoint in ``train_dir`` is restored (graph structure from
    its ``.meta`` file, then the variable values). Batches obtained from
    ``get_input_data_tensors`` are fed to the tensor registered in the
    ``"input_batch_raw"`` collection, the tensor registered in the
    ``"predictions"`` collection is evaluated, and the lines returned by
    ``format_lines`` are appended to the output file after an
    ``Id,Category`` header row.

    Args:
      reader: Passed through unchanged to ``get_input_data_tensors``.
      train_dir: Directory searched for the latest checkpoint.
      data_pattern: Passed through unchanged to ``get_input_data_tensors``.
      out_file_location: Path of the output file; opened with mode ``"w+"``.
      batch_size: Passed through unchanged to ``get_input_data_tensors``.

    Raises:
      Exception: If no checkpoint can be found in ``train_dir``.
    """
    with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, \
            gfile.Open(out_file_location, "w+") as out_file:
        image_id_batch, image_batch = get_input_data_tensors(
            reader, data_pattern, batch_size)

        latest_checkpoint = tf.train.latest_checkpoint(train_dir)
        if latest_checkpoint is None:
            raise Exception(
                "unable to find a checkpoint at location: %s" % train_dir)

        meta_graph_location = latest_checkpoint + ".meta"
        logging.info("loading meta-graph: %s", meta_graph_location)
        saver = tf.train.import_meta_graph(
            meta_graph_location, clear_devices=True)
        logging.info("restoring variables from %s", latest_checkpoint)
        saver.restore(sess, latest_checkpoint)
        input_tensor = tf.get_collection("input_batch_raw")[0]
        predictions_tensor = tf.get_collection("predictions")[0]

        # Workaround for num_epochs issue.
        def set_up_init_ops(variables):
            """Build the init ops for the given local-variable collection.

            Variables whose name contains "train_input" are assigned 1 and
            removed from ``variables`` (which is mutated in place); all
            remaining variables get the standard initializer.
            """
            init_op_list = []
            # Iterate over a snapshot because `variables` is mutated below.
            for variable in list(variables):
                if "train_input" in variable.name:
                    init_op_list.append(tf.assign(variable, 1))
                    variables.remove(variable)
            init_op_list.append(tf.variables_initializer(variables))
            return init_op_list

        # get_collection_ref returns the live collection list, so the
        # removals performed by set_up_init_ops affect the graph collection.
        sess.run(set_up_init_ops(
            tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        num_examples_processed = 0
        start_time = time.time()
        out_file.write("Id,Category\n")

        try:
            while not coord.should_stop():
                image_id_batch_val, image_batch_val = sess.run(
                    [image_id_batch, image_batch])
                predictions_val, = sess.run(
                    [predictions_tensor],
                    feed_dict={input_tensor: image_batch_val})
                num_examples_processed += len(image_batch_val)
                logging.info(
                    "num examples processed: %d elapsed seconds: %.2f",
                    num_examples_processed, time.time() - start_time)
                for line in format_lines(image_id_batch_val, predictions_val):
                    out_file.write(line)
                # Flush once per batch so partial results reach the file.
                out_file.flush()
        except tf.errors.OutOfRangeError:
            # OutOfRangeError ends the loop and is treated as normal
            # completion (presumably the input queue is exhausted).
            logging.info(
                "Done with inference. The output file was written to %s",
                out_file_location)
        finally:
            coord.request_stop()

        coord.join(threads)
        # No explicit sess.close(): leaving the `with` block closes the
        # session (and the output file).
inference.py 文件源码
python
阅读 27
收藏 0
点赞 0
评论 0
评论列表
文章目录