def _find_checkpoint(train_dir):
    """Return the checkpoint prefix to restore.

    ``FLAGS.model_checkpoint_path`` wins when it is set; otherwise the most
    recent checkpoint in ``train_dir`` (per ``tf.train.latest_checkpoint``)
    is used.

    Args:
        train_dir: Directory searched when no explicit checkpoint is given.

    Returns:
        The checkpoint prefix, without the ``.meta`` suffix.

    Raises:
        Exception: if ``train_dir`` contains no checkpoint.
    """
    if FLAGS.model_checkpoint_path:
        return FLAGS.model_checkpoint_path
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
        raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    return latest_checkpoint


def _set_up_init_ops(variables):
    """Build the ops that initialise the given (local) variables.

    Workaround for the num_epochs issue: every variable whose name contains
    ``"train_input"`` is assigned the value 1 instead of being initialised,
    and is removed from ``variables``. All remaining variables get a regular
    ``tf.variables_initializer``.

    Args:
        variables: Mutable list of variables, here the live LOCAL_VARIABLES
            collection from ``tf.get_collection_ref``. It is updated in place,
            so the "train_input" entries leave the graph collection.

    Returns:
        List of ops to pass to ``Session.run``.
    """
    special = [v for v in variables if "train_input" in v.name]
    remaining = [v for v in variables if "train_input" not in v.name]
    # Partition instead of calling list.remove(), which would rely on
    # tf.Variable.__eq__ (element-wise, not identity, under TF2 behavior).
    variables[:] = remaining
    init_op_list = [tf.assign(v, 1) for v in special]
    init_op_list.append(tf.variables_initializer(remaining))
    return init_op_list


def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
    """Restore a trained model and save its forward parameters as text files.

    The checkpoint comes from ``FLAGS.model_checkpoint_path`` if set, else the
    latest checkpoint in ``train_dir``. The tensors returned by
    ``get_forward_parameters`` are evaluated once. Each resulting array ``i``
    is written with ``np.savetxt`` to ``<train_dir>/autoencoder_layer<i>.model``.

    Args:
        reader: Input reader; forwarded to ``get_input_data_tensors``, and its
            ``num_classes`` is passed as ``vocab_size`` to
            ``get_forward_parameters``.
        train_dir: Directory holding the checkpoints; also the output directory.
        data_pattern: Forwarded to ``get_input_data_tensors``.
        out_file_location: Opened with mode ``"w+"`` (so it is created or
            truncated) but nothing is written to it; kept for compatibility.
        batch_size: Forwarded to ``get_input_data_tensors``.
        top_k: Unused; kept for interface compatibility.

    Raises:
        Exception: if no checkpoint can be found in ``train_dir``.
    """
    import os  # only needed to build the output file paths

    with tf.Session() as sess, gfile.Open(out_file_location, "w+"):
        # Returned tensors are unused here. Presumably the call is kept because
        # it builds the input queue pipeline whose runners are started below.
        # NOTE(review): confirm it is still required.
        get_input_data_tensors(reader, data_pattern, batch_size)

        latest_checkpoint = _find_checkpoint(train_dir)
        meta_graph_location = latest_checkpoint + ".meta"
        logging.info("loading meta-graph: " + meta_graph_location)
        saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
        logging.info("restoring variables from " + latest_checkpoint)
        saver.restore(sess, latest_checkpoint)

        parameters = get_forward_parameters(vocab_size=reader.num_classes)

        sess.run(_set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            logging.info("start saving parameters")
            params = sess.run(parameters)
            for i, param in enumerate(params):
                path = os.path.join(train_dir, "autoencoder_layer%d.model" % i)
                # np.savetxt accepts only 1-D/2-D arrays; presumably each
                # parameter is a weight matrix or bias vector -- confirm.
                np.savetxt(path, param)
                logging.info("saved parameter %d (shape %s) to %s", i, np.shape(param), path)
        except tf.errors.OutOfRangeError:
            logging.info('Done with inference. The output file was written to ' + out_file_location)
        finally:
            # Stop the queue-runner threads; the `with` block closes the session.
            coord.request_stop()
            coord.join(threads)
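# 评论列表 ("comment list") -- leftover label from the web page this code was copied from; not code.
# 文章目录 ("table of contents") -- leftover label from the web page this code was copied from; not code.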