def test():
    """Run one MNIST test batch through the restored model and log it to TensorBoard.

    Builds the network (``layers()``) under the ``'model'`` variable scope and
    initializes/reloads its variables with ``init_vars``. It then feeds one batch
    of ``FLAGS.batch_size`` test images, rescaled by ``2*x - 1``, and writes the
    merged summaries to ``FLAGS.summaries_dir + '/my_model'``.

    When ``FLAGS.relevance`` is set, it also builds the LRP relevance op
    (``lrp(output, 'simple', 1.0)``) and renders the relevances as heatmaps next
    to the input images via ``plot_relevances``. When the flag is not set, only
    the summaries are written.
    """
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)
    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, [FLAGS.batch_size, 784], name='input')
        # Stays None unless relevance propagation is requested, so the code
        # below can tell whether there is a relevance tensor to fetch.
        relevance = None
        with tf.variable_scope('model'):
            my_network = layers()
            output = my_network.forward(x)
            if FLAGS.relevance:
                relevance = my_network.lrp(output, 'simple', 1.0)

        # Merge all the summaries and write them out.
        merged = tf.summary.merge_all()
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/my_model')
        try:
            # Initialize variables and reload the trained model (the returned
            # saver is not needed here).
            init_vars(sess)

            # Extract one batch of test data; labels are not used.
            xs, _ = mnist.test.next_batch(FLAGS.batch_size)
            # Rescale exactly as the model expects and keep this tensor, so the
            # displayed images are derived from what the network actually saw.
            # NOTE(review): assumes read_data_sets yields pixels in [0, 1] — confirm.
            test_inputs = (2 * xs) - 1

            # Pass the test data to the restored model; only fetch the
            # relevance op when it was built.
            fetches = [merged] if relevance is None else [merged, relevance]
            results = sess.run(fetches, feed_dict={x: test_inputs})
            test_writer.add_summary(results[0], 0)

            if relevance is not None:
                # Save the images as heatmaps to visualize on TensorBoard.
                # Invert the 2*x - 1 rescaling of the fed inputs for display.
                images = test_inputs.reshape([FLAGS.batch_size, 28, 28, 1])
                images = (images + 1) / 2.0
                relevances = results[1].reshape([FLAGS.batch_size, 28, 28, 1])
                plot_relevances(relevances, images, test_writer)
        finally:
            # Always flush/close the event file, even if a step above fails.
            test_writer.close()
评论列表
文章目录