def text_cnn_load_model_and_eval(x_test,
                                 checkpoint_file,
                                 allow_soft_placement,
                                 log_device_placement,
                                 embeddings,
                                 batch_size=50):
    """Restore a saved TextCNN checkpoint and predict labels for ``x_test``.

    The meta graph ``<checkpoint_file>.meta`` is imported into a fresh graph,
    its variables are restored from ``checkpoint_file``, and ``x_test`` is fed
    through the ``output/predictions`` tensor in sequential (unshuffled)
    batches with ``dropout_keep_prob`` set to 1.0.

    Args:
        x_test: Test inputs fed to the graph's ``input_x`` placeholder; split
            into batches with ``data_helpers.batch_iter``.
        checkpoint_file: Checkpoint path prefix (without the ``.meta`` suffix).
        allow_soft_placement: Forwarded to ``tf.ConfigProto``.
        log_device_placement: Forwarded to ``tf.ConfigProto``.
        embeddings: Embedding matrix fed to the graph's
            ``embedding/Placeholder`` tensor on every batch; ``shape[0]`` is
            reported as the number of embeddings and ``shape[1]`` as the
            embedding size.
        batch_size: Number of examples per ``sess.run`` call (default 50,
            the value that was previously hard-coded).

    Returns:
        A 1-D numpy array with the predictions of all batches, in input order.
        As before, values are promoted to float64; an empty ``x_test`` yields
        an empty float64 array (previously a plain empty list).
    """
    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(
            allow_soft_placement=allow_soft_placement,
            log_device_placement=log_device_placement)
        # Using the session as a context manager installs it as the default
        # session and closes it on exit (the original never closed it).
        with tf.Session(config=session_conf) as sess:
            # Load the saved meta graph and restore variables.
            saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_file))
            saver.restore(sess, checkpoint_file)

            # Get the placeholders and the output tensor from the graph by name.
            input_x = graph.get_operation_by_name("input_x").outputs[0]
            dropout_keep_prob = graph.get_operation_by_name(
                "dropout_keep_prob").outputs[0]
            predictions = graph.get_operation_by_name(
                "output/predictions").outputs[0]

            # Generate batches for exactly one epoch, preserving input order.
            batches = data_helpers.batch_iter(x_test, batch_size, 1, shuffle=False)

            embedding_size = embeddings.shape[1]
            embeddings_number = embeddings.shape[0]
            # Parenthesised single-argument print: same output on Python 2 and 3.
            print('embedding_size:%s, embeddings_number:%s'
                  % (embedding_size, embeddings_number))
            embeddings_placeholder = graph.get_operation_by_name(
                "embedding/Placeholder").outputs[0]

            # The dropout and embedding feeds are the same for every batch, so
            # build the feed dict once and only swap in the current input batch.
            feed_dict = {dropout_keep_prob: 1.0,
                         embeddings_placeholder: embeddings}
            batch_results = []
            for x_test_batch in batches:
                feed_dict[input_x] = x_test_batch
                batch_results.append(sess.run(predictions, feed_dict))

    # Concatenate once rather than once per batch (which is quadratic). The
    # leading empty list reproduces the original float64 promotion and makes
    # an empty ``x_test`` return an ndarray instead of a plain list.
    return np.concatenate([[]] + batch_results)
评论列表
文章目录