Python get_collection_ref() examples

import time

import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging

# get_input_data_tensors() and format_lines() are helpers defined elsewhere
# in the same file as each of these snippets.


def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, \
      gfile.Open(out_file_location, "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
        reader, data_pattern, batch_size)

    # Restore the latest checkpoint and its meta-graph.
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    else:
      meta_graph_location = latest_checkpoint + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)

    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for the num_epochs issue: local variables belonging to the
    # training input pipeline are set to 1 instead of being re-initialized,
    # while every other local variable is initialized normally.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        predictions_val, = sess.run(
            [predictions_tensor],
            feed_dict={input_tensor: video_batch_val,
                       num_frames_tensor: num_frames_batch_val})
        now = time.time()
        num_examples_processed += len(video_batch_val)
        num_classes = predictions_val.shape[1]
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(now - start_time))
        for line in format_lines(video_id_batch_val, predictions_val, top_k):
          out_file.write(line)
        out_file.flush()
    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to " + out_file_location)
    finally:
      coord.request_stop()
      coord.join(threads)
      sess.close()
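
The format_lines() helper is not shown on this page. Given the
"VideoId,LabelConfidencePairs" header, it presumably yields one CSV row per
video with the top_k (label, score) pairs sorted by score; a minimal sketch
under that assumption (not the original implementation):

import numpy as np

def format_lines(video_ids, predictions, top_k):
  """Sketch: yield 'video_id,label score label score ...' rows (assumed format)."""
  for row_index in range(len(video_ids)):
    row = predictions[row_index]
    # Unordered indices of the top_k largest scores, then sorted by score.
    top_indices = np.argpartition(row, -top_k)[-top_k:]
    pairs = sorted(((int(i), float(row[i])) for i in top_indices),
                   key=lambda p: -p[1])
    video_id = video_ids[row_index]
    if isinstance(video_id, bytes):  # ids come out of sess.run() as bytes
      video_id = video_id.decode("utf-8")
    yield video_id + "," + " ".join("%i %g" % pair for pair in pairs) + "\n"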
def inference(reader, train_dir, data_pattern, out_file_location, batch_size):
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, \
      gfile.Open(out_file_location, "w+") as out_file:
    image_id_batch, image_batch = get_input_data_tensors(reader, data_pattern, batch_size)
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    else:
      meta_graph_location = latest_checkpoint + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("Id,Category\n")

    try:
      while not coord.should_stop():
        image_id_batch_val, image_batch_val = sess.run([image_id_batch, image_batch])
        predictions_val, = sess.run([predictions_tensor],
                                    feed_dict={input_tensor: image_batch_val})
        now = time.time()
        num_examples_processed += len(image_batch_val)
        num_classes = predictions_val.shape[1]
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(now - start_time))
        for line in format_lines(image_id_batch_val, predictions_val):
          out_file.write(line)
        out_file.flush()
    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to " + out_file_location)
    finally:
      coord.request_stop()
      coord.join(threads)
      sess.close()
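
For this "Id,Category" variant, format_lines() presumably reduces each
prediction row to a single predicted class; a hypothetical sketch:

import numpy as np

def format_lines(image_ids, predictions):
  """Sketch: yield 'image_id,argmax_category' rows (assumed format)."""
  for image_id, row in zip(image_ids, predictions):
    if isinstance(image_id, bytes):
      image_id = image_id.decode("utf-8")
    yield "%s,%d\n" % (image_id, int(np.argmax(row)))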
def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
  with tf.Session() as sess, gfile.Open(out_file_location, "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
        reader, data_pattern, batch_size)
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    else:
      meta_graph_location = latest_checkpoint + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        predictions_val, = sess.run(
            [predictions_tensor],
            feed_dict={input_tensor: video_batch_val,
                       num_frames_tensor: num_frames_batch_val})
        now = time.time()
        num_examples_processed += len(video_batch_val)
        num_classes = predictions_val.shape[1]
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(now - start_time))
        for line in format_lines(video_id_batch_val, predictions_val, top_k):
          out_file.write(line)
        out_file.flush()
    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to " + out_file_location)
    finally:
      coord.request_stop()
      coord.join(threads)
      sess.close()
def inference(reader, train_dir, data_pattern, out_file_location, batch_size):
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, \
      gfile.Open(out_file_location, "w+") as out_file:
    image_batch = get_input_data_tensors(reader, data_pattern, batch_size)
    latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    else:
      meta_graph_location = latest_checkpoint + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("Id,Category\n")

    try:
      line_id = 1
      while not coord.should_stop():
        image_batch_val = sess.run(image_batch)
        predictions_val = sess.run(predictions_tensor,
                                   feed_dict={input_tensor: image_batch_val})
        now = time.time()
        num_examples_processed += len(image_batch_val)
        num_classes = predictions_val.shape[1]
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(now - start_time))
        for line in format_lines(predictions_val):
          out_file.write("%d,%s" % (line_id, line))
          line_id += 1
        out_file.flush()
    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to " + out_file_location)
    finally:
      coord.request_stop()
      coord.join(threads)
      sess.close()
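
In this variant the caller prepends a running line_id, so format_lines() here
presumably yields only the predicted category per row; a hypothetical sketch:

import numpy as np

def format_lines(predictions):
  """Sketch: yield one predicted category per row; the caller adds the id."""
  for row in predictions:
    yield "%d\n" % int(np.argmax(row))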
def do_evaluation():
  # Load the data early to get node_type_num.
  ds = data.load_dataset('data/statements')
  hyper.node_type_num = len(ds.word2int)

  (compiler, _, _, _, raw_accuracy, batch_size_op) = build_model()

  # Restorer for the embedding matrix.
  embedding_path = tf.train.latest_checkpoint(hyper.embedding_dir)
  if embedding_path is None:
    raise ValueError('Path to embedding checkpoint is incorrect: ' + hyper.embedding_dir)
  # Restorer for the other variables.
  checkpoint_path = tf.train.latest_checkpoint(hyper.train_dir)
  if checkpoint_path is None:
    raise ValueError('Path to tbcnn checkpoint is incorrect: ' + hyper.train_dir)
  restored_vars = tf.get_collection_ref('restored')
  restored_vars.append(param.get('We'))
  restored_vars.extend(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))
  embeddingRestorer = tf.train.Saver({'embedding/We': param.get('We')})
  restorer = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES))

  # Evaluation loop.
  total_size, test_gen = ds.get_split('test')
  test_set = compiler.build_loom_inputs(test_gen)
  with tf.Session() as sess:
    # Restore the embedding matrix first.
    embeddingRestorer.restore(sess, embedding_path)
    # Restore the other variables.
    restorer.restore(sess, checkpoint_path)
    # Initialize everything that was not restored.
    gvariables = [v for v in tf.global_variables() if v not in tf.get_collection('restored')]
    sess.run(tf.variables_initializer(gvariables))

    num_epochs = 1 if not hyper.warm_up else 3
    for shuffled in td.epochs(test_set, num_epochs):
      logger.info('')
      logger.info('======================= Evaluation ====================================')
      accumulated_accuracy = 0.
      start_time = default_timer()
      for step, batch in enumerate(td.group_by_batches(shuffled, hyper.batch_size), 1):
        feed_dict = {compiler.loom_input_tensor: batch}
        accuracy_value, actual_bsize = sess.run([raw_accuracy, batch_size_op], feed_dict)
        accumulated_accuracy += accuracy_value * actual_bsize
        logger.info('evaluation in progress: running accuracy = %.2f, processed = %d / %d',
                    accuracy_value, (step - 1) * hyper.batch_size + actual_bsize, total_size)
      duration = default_timer() - start_time
      total_accuracy = accumulated_accuracy / total_size
      logger.info('evaluation accumulated accuracy = %.2f%% (%.1f samples/sec; %.2f seconds)',
                  total_accuracy * 100, total_size / duration, duration)
      logger.info('======================= Evaluation End =================================')
      logger.info('')
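
Note how do_evaluation() weights each batch accuracy by the actual batch size
before averaging, so a short final batch does not skew the result. The
aggregation in isolation (the names here are illustrative):

def weighted_mean_accuracy(batch_results):
  """batch_results: iterable of (accuracy, actual_batch_size) pairs."""
  accumulated, total = 0.0, 0
  for accuracy, size in batch_results:
    accumulated += accuracy * size
    total += size
  return accumulated / total

# Two full batches of 32 and a final short batch of 8:
assert abs(weighted_mean_accuracy([(0.75, 32), (0.8125, 32), (0.5, 8)]) - 0.75) < 1e-9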
def inference(reader, train_dir, data_pattern, out_file_location, batch_size, top_k):
  with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess, \
      gfile.Open(out_file_location, "w+") as out_file:
    video_id_batch, video_batch, num_frames_batch = get_input_data_tensors(
        reader, data_pattern, batch_size)
    if FLAGS.checkpoint_name == "":
      latest_checkpoint = tf.train.latest_checkpoint(train_dir)
    else:
      # NOTE: this assumes FLAGS.train_dir already ends with a path separator.
      latest_checkpoint = FLAGS.train_dir + "model.ckpt-" + FLAGS.checkpoint_name
    if latest_checkpoint is None:
      raise Exception("unable to find a checkpoint at location: %s" % train_dir)
    else:
      meta_graph_location = latest_checkpoint + ".meta"
      logging.info("loading meta-graph: " + meta_graph_location)
    saver = tf.train.import_meta_graph(meta_graph_location, clear_devices=True)
    logging.info("restoring variables from " + latest_checkpoint)
    saver.restore(sess, latest_checkpoint)
    input_tensor = tf.get_collection("input_batch_raw")[0]
    num_frames_tensor = tf.get_collection("num_frames")[0]
    predictions_tensor = tf.get_collection("predictions")[0]

    # Workaround for num_epochs issue.
    def set_up_init_ops(variables):
      init_op_list = []
      for variable in list(variables):
        if "train_input" in variable.name:
          init_op_list.append(tf.assign(variable, 1))
          variables.remove(variable)
      init_op_list.append(tf.variables_initializer(variables))
      return init_op_list

    sess.run(set_up_init_ops(tf.get_collection_ref(tf.GraphKeys.LOCAL_VARIABLES)))

    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    num_examples_processed = 0
    start_time = time.time()
    out_file.write("VideoId,LabelConfidencePairs\n")

    try:
      while not coord.should_stop():
        video_id_batch_val, video_batch_val, num_frames_batch_val = sess.run(
            [video_id_batch, video_batch, num_frames_batch])
        predictions_val, = sess.run(
            [predictions_tensor],
            feed_dict={input_tensor: video_batch_val,
                       num_frames_tensor: num_frames_batch_val})
        now = time.time()
        num_examples_processed += len(video_batch_val)
        num_classes = predictions_val.shape[1]
        logging.info("num examples processed: " + str(num_examples_processed) +
                     " elapsed seconds: " + "{0:.2f}".format(now - start_time))
        for line in format_lines(video_id_batch_val, predictions_val, top_k):
          out_file.write(line)
        out_file.flush()
    except tf.errors.OutOfRangeError:
      logging.info("Done with inference. The output file was written to " + out_file_location)
    finally:
      coord.request_stop()
      coord.join(threads)
      sess.close()
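
The snippet above reads FLAGS.checkpoint_name and FLAGS.train_dir; in the
TF 1.x flags style they would be declared roughly as follows (the defaults
and help strings are assumptions):

FLAGS = tf.app.flags.FLAGS

# An empty checkpoint_name means "use the latest checkpoint in train_dir".
tf.app.flags.DEFINE_string(
    "checkpoint_name", "",
    "Global step of the checkpoint to load, e.g. '10000' for model.ckpt-10000.")
tf.app.flags.DEFINE_string(
    "train_dir", "/tmp/yt8m_model/",
    "Directory containing the model checkpoints.")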