def inference(video_id_batch, prediction_batch, label_batch, saver,
              out_file_location, num_classes=4716):
  """Restore a checkpoint and run a single inference pass over one-hot inputs.

  Feeds an identity matrix (one one-hot row per class) through the graph via
  `label_batch`, evaluates `prediction_batch`, and writes the resulting
  predictions to `out_file_location` with `np.savetxt`.

  Args:
    video_id_batch: video-id tensor from the input pipeline. Unused here;
      kept so the signature stays compatible with existing callers.
    prediction_batch: tensor of model predictions to evaluate.
    label_batch: tensor/placeholder that receives the one-hot identity matrix.
    saver: tf.train.Saver used to restore model variables.
    out_file_location: path of the text file the predictions are written to.
    num_classes: number of rows/columns of the identity input matrix.
      Defaults to 4716, the YouTube-8M label vocabulary size, preserving the
      original hard-coded behavior.

  Returns:
    The global step parsed from the checkpoint filename (as a string), or -1
    if no checkpoint was found.
  """
  global_step_val = -1
  with tf.Session() as sess:
    if FLAGS.model_checkpoint_path:
      checkpoint = FLAGS.model_checkpoint_path
    else:
      checkpoint = tf.train.latest_checkpoint(FLAGS.train_dir)
    if checkpoint:
      logging.info("Loading checkpoint for eval: " + checkpoint)
      # Restores all model variables from the checkpoint.
      saver.restore(sess, checkpoint)
      # Assuming model_checkpoint_path looks something like:
      # /my-favorite-path/yt8m_train/model.ckpt-0, extract global_step from it.
      global_step_val = checkpoint.split("/")[-1].split("-")[-1]
    else:
      logging.info("No checkpoint file found.")
      return global_step_val

    sess.run([tf.local_variables_initializer()])

    # Workaround for the num_epochs issue: set any "train_input" epoch
    # counters to 1 so the input queue is usable for a single pass, and
    # initialize the remaining local variables normally.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(
        tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    # One-hot probe inputs: row i activates exactly class i.
    input_indices = np.eye(num_classes)
    try:
      print("start saving parameters")
      predictions = sess.run(prediction_batch,
                             feed_dict={label_batch: input_indices})
      np.savetxt(out_file_location, predictions)
    except tf.errors.OutOfRangeError:
      logging.info('Done with inference. The output file was written to ' +
                   out_file_location)
    finally:
      coord.request_stop()
      coord.join(threads)
      # NOTE(review): the `with tf.Session()` block closes the session on
      # exit, so the original explicit sess.close() was redundant.
  return global_step_val