def predict(checkpoint_dir=None, test_len=500, data_dir="./data/"):
    """Predict unseen images with a trained RNN checkpoint and report accuracy.

    Rebuilds the RNN graph via ``rnn_model``, restores the latest checkpoint
    found in ``<checkpoint_dir>/checkpoints``, and evaluates classification
    accuracy on the first ``test_len`` MNIST test images.

    Args:
        checkpoint_dir: Directory that contains a ``checkpoints``
            sub-directory. Defaults to ``sys.argv[1]`` (the original
            command-line behaviour).
        test_len: Number of leading MNIST test images to evaluate.
        data_dir: Directory passed to ``input_data.read_data_sets``.

    Returns:
        The accuracy value computed by ``sess.run`` on the accuracy tensor.

    Raises:
        IndexError: ``checkpoint_dir`` is omitted and no command-line
            argument was given.
        ValueError: no checkpoint was found in the checkpoints directory.
    """
    import os  # local import: the file's top-level imports are not visible here

    # Step 0: resolve the checkpoint location first, so a missing
    # command-line argument fails before the (possibly slow) dataset load.
    if checkpoint_dir is None:
        checkpoint_dir = sys.argv[1]
    # os.path.join avoids producing "<dir>checkpoints" when the caller's
    # directory has no trailing separator.
    checkpoint_root = os.path.join(checkpoint_dir, 'checkpoints')
    mnist = input_data.read_data_sets(data_dir, one_hot=True)

    # Step 1: build the rnn model. The variable names ('weights', 'biases',
    # plus whatever rnn_model creates) presumably must match the training
    # graph so the Saver can restore them -- confirm against the training code.
    x = tf.placeholder("float", [None, n_steps, n_input])
    y = tf.placeholder("float", [None, n_classes])
    weights = tf.Variable(tf.random_normal([n_hidden, n_classes]), name='weights')
    biases = tf.Variable(tf.random_normal([n_classes]), name='biases')
    pred = rnn_model(x, weights, biases)
    # A prediction is correct when its arg-max class equals the one-hot label's.
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Step 2: predict new images with the trained model.
    with tf.Session() as sess:
        # Replaces the deprecated tf.initialize_all_variables().
        sess.run(tf.global_variables_initializer())

        # Step 2.0: load the trained model.
        checkpoint_file = tf.train.latest_checkpoint(checkpoint_root)
        if checkpoint_file is None:
            # latest_checkpoint returns None when nothing is found; fail
            # clearly instead of passing None on to saver.restore.
            raise ValueError(
                'No checkpoint found in {}'.format(checkpoint_root))
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint_file)
        # Report only after the restore has actually succeeded.
        print('Loaded the trained model: {}'.format(checkpoint_file))

        # Step 2.1: predict new data. Each flattened image is reshaped to
        # (n_steps, n_input); assumes n_steps * n_input equals the image
        # size -- TODO confirm.
        test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))
        test_label = mnist.test.labels[:test_len]
        test_accuracy = sess.run(accuracy, feed_dict={x: test_data, y: test_label})
        print("Testing Accuracy:", test_accuracy)
    return test_accuracy
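# (Web-page residue captured with this snippet: "评论列表" = "comment list",
#  "文章目录" = "table of contents". Not part of the program.)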