model.py source code


Project: interprettensor  Author: VigneshSrinivasan10
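```python
import tensorflow as tf  # TensorFlow 1.x API (tf.Session, tf.placeholder)
from tensorflow.examples.tutorials.mnist import input_data

# FLAGS, layers, init_vars and plot_relevances are defined elsewhere in the
# interprettensor project; they are not part of this snippet.


def test():
    mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

    with tf.Session() as sess:
        x = tf.placeholder(tf.float32, [FLAGS.batch_size, 784], name='input')
        with tf.variable_scope('model'):
            my_network = layers()
            output = my_network.forward(x)
            if FLAGS.relevance:
                # Layer-wise relevance propagation (LRP) with the 'simple' rule
                RELEVANCE = my_network.lrp(output, 'simple', 1.0)

        # Merge all the summaries and write them out
        merged = tf.summary.merge_all()
        test_writer = tf.summary.FileWriter(FLAGS.summaries_dir + '/my_model')

        # Initialize variables and reload the trained model
        saver = init_vars(sess)

        # Extract a batch of test data and rescale pixels from [0, 1] to [-1, 1]
        xs, ys = mnist.test.next_batch(FLAGS.batch_size)
        xs = 2 * xs - 1

        if not FLAGS.relevance:
            # RELEVANCE only exists when FLAGS.relevance is set
            test_writer.close()
            return

        # Pass the test data to the restored model
        summary, relevance_test = sess.run([merged, RELEVANCE], feed_dict={x: xs})
        test_writer.add_summary(summary, 0)

        # Save the images as heatmaps to visualize on TensorBoard;
        # map the inputs back from [-1, 1] to [0, 1] for display
        images = xs.reshape([FLAGS.batch_size, 28, 28, 1])
        images = (images + 1) / 2.0
        relevances = relevance_test.reshape([FLAGS.batch_size, 28, 28, 1])
        plot_relevances(relevances, images, test_writer)

        test_writer.close()
```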
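The call `lrp(output, 'simple', 1.0)` propagates the network's output score back onto the input pixels to produce the relevance heatmaps. As an illustration only, the sketch below implements the basic LRP "z-rule" for dense layers in NumPy, assuming that `'simple'` refers to this rule; it is not the interprettensor implementation, it does not model the third argument (`1.0`), and the toy network, its weights, and the helper `lrp_simple_dense` are hypothetical.

```python
import numpy as np


def lrp_simple_dense(x, w, b, r_out, stabilizer=1e-12):
    """Basic LRP z-rule for one dense layer y = x @ w + b (hypothetical helper).

    x: (n_in,) input activations, w: (n_in, n_out), b: (n_out,),
    r_out: (n_out,) relevance arriving at the layer's outputs.
    Returns the (n_in,) relevance redistributed onto the inputs.
    """
    z = x[:, None] * w                                    # contributions z_ij = x_i * w_ij
    zs = z.sum(axis=0) + b                                # pre-activations z_j
    zs = zs + stabilizer * np.where(zs >= 0, 1.0, -1.0)   # avoid division by zero
    return (z / zs * r_out).sum(axis=1)                   # R_i = sum_j z_ij / z_j * R_j


# Toy two-layer ReLU network with zero biases (hypothetical random weights).
rng = np.random.default_rng(0)
w1, b1 = rng.normal(size=(4, 3)), np.zeros(3)
w2, b2 = rng.normal(size=(3, 2)), np.zeros(2)
x = rng.normal(size=4)

h = np.maximum(0.0, x @ w1 + b1)   # hidden ReLU activations
y = h @ w2 + b2                    # output scores

# Explain only the predicted class: its score is the initial relevance.
k = int(np.argmax(y))
r_y = np.zeros_like(y)
r_y[k] = y[k]

r_h = lrp_simple_dense(h, w2, b2, r_y)   # the ReLU passes relevance through unchanged
r_x = lrp_simple_dense(x, w1, b1, r_h)
print("input relevance:", r_x, "sum:", r_x.sum(), "score:", y[k])
```

With zero biases the relevance is conserved layer by layer, so `r_x.sum()` matches the explained score `y[k]` up to the stabilizer; nonzero biases absorb part of the relevance. In the model above, reshaping the 784 input relevances to 28×28 gives the heatmap passed to `plot_relevances`.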