def text_cnn_load_model_and_eval_v2(x_test_s1,
                                    x_test_s2,
                                    checkpoint_file,
                                    allow_soft_placement,
                                    log_device_placement,
                                    embeddings,
                                    batch_size=50):
    """Restore a saved two-input text-CNN checkpoint and predict on test data.

    The meta graph stored at ``"<checkpoint_file>.meta"`` is imported into a
    fresh ``tf.Graph``, its variables are restored from ``checkpoint_file``,
    and the ``output/predictions`` tensor is evaluated over the pairs
    ``zip(x_test_s1, x_test_s2)`` in a single, unshuffled pass with
    ``dropout_keep_prob`` fed as 1.0.

    Args:
        x_test_s1: Inputs fed to the ``input_x_s1`` placeholder; paired
            element-wise with ``x_test_s2``.
        x_test_s2: Inputs fed to the ``input_x_s2`` placeholder.
        checkpoint_file: Checkpoint path prefix passed to ``saver.restore``.
        allow_soft_placement: Forwarded to ``tf.ConfigProto``.
        log_device_placement: Forwarded to ``tf.ConfigProto``.
        embeddings: Array fed to the graph's ``embedding/Placeholder`` on every
            batch; ``shape[1]`` / ``shape[0]`` are printed as embedding size /
            number of embeddings.
        batch_size: Examples per ``sess.run`` call (default 50, the value that
            was previously hard-coded).

    Returns:
        A 1-D NumPy array with one prediction per test pair, in input order,
        in the dtype produced by ``output/predictions`` (previously the result
        was accidentally upcast to float64). An empty array is returned when
        there is no test data.
    """
    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=allow_soft_placement,
            log_device_placement=log_device_placement)
        # Using the session as a context manager makes it the default session
        # (as ``sess.as_default()`` did) and also closes it on exit, releasing
        # its resources; the session used to be leaked.
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables.
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders and the output tensor from the graph by name.
            input_x_s1 = graph.get_operation_by_name("input_x_s1").outputs[0]
            input_x_s2 = graph.get_operation_by_name("input_x_s2").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name("dropout_keep_prob").outputs[0]
            predictions = graph.get_operation_by_name("output/predictions").outputs[0]

            # One epoch over the test pairs, in order.
            batches = data_helpers.batch_iter(
                list(zip(x_test_s1, x_test_s2)), batch_size, 1, shuffle=False)

            embedding_size = embeddings.shape[1]
            embeddings_number = embeddings.shape[0]
            print('embedding_size:%s, embeddings_number:%s' % (embedding_size, embeddings_number))

            # NOTE(review): "embedding/Placeholder" looks like TensorFlow's default
            # name for an unnamed tf.placeholder created under the "embedding"
            # name scope. It will change if the training graph adds another
            # unnamed placeholder there; giving it an explicit name is safer.
            embeddings_placeholder = graph.get_operation_by_name("embedding/Placeholder").outputs[0]

            # Collect per-batch results and concatenate once at the end:
            # concatenating inside the loop is quadratic, and seeding it with an
            # empty Python list silently upcast the predictions to float64.
            prediction_chunks = []
            for batch in batches:
                x_test_batch_s1, x_test_batch_s2 = zip(*batch)
                prediction_chunks.append(sess.run(predictions, {
                    input_x_s1: x_test_batch_s1,
                    input_x_s2: x_test_batch_s2,
                    dropout_keep_prob: 1.0,  # no dropout at inference time
                    embeddings_placeholder: embeddings,
                }))

    if not prediction_chunks:
        return np.array([])
    return np.concatenate(prediction_chunks)
评论列表
文章目录