network.py source code

python

Project: cifar10-tensorflow  Author: persistforever
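# Excerpt: a method of the network class in network.py (TensorFlow 1.x API).
# `cifar10` is a dataset object defined elsewhere in the project.
import os

import numpy
import tensorflow as tf


def test(self, backup_path, epoch, batch_size=128):
    saver = tf.train.Saver(write_version=tf.train.SaverDef.V2)
    # Cap this process at 45% of the GPU memory
    gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.45)
    # Locate the checkpoint saved for the requested epoch
    model_path = os.path.join(backup_path, 'model_%d.ckpt' % epoch)
    assert os.path.exists(model_path + '.index')
    # The context manager closes the session once evaluation is done
    with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
        saver.restore(sess, model_path)
        print('read model from %s' % model_path)
        # Evaluate on the test set, one full batch at a time
        precision = []
        for batch in range(cifar10.test.num_examples // batch_size):
            batch_image, batch_label = cifar10.test.next_batch(batch_size)
            [precision_onebatch] = sess.run(
                fetches=[self.accuracy],
                feed_dict={self.image: batch_image,
                           self.label: batch_label,
                           self.keep_prob: 1.0})  # keep_prob 1.0 disables dropout
            precision.append(precision_onebatch)
    # Mean of the per-batch accuracies (labelled "precision" in the output)
    print('test precision: %.4f' % numpy.mean(precision))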
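
The evaluation loop above averages per-batch accuracy over the number of full batches that fit in the test set, so any trailing examples that do not fill a batch are never scored. The following is a minimal, framework-free sketch of that averaging scheme. The names `batched_accuracy`, `parity_model`, and the toy data are hypothetical and are not part of the cifar10-tensorflow project; the sketch takes a prediction function rather than fetching an accuracy tensor from a session.

```python
import numpy as np


def batched_accuracy(predict_fn, images, labels, batch_size=128):
    """Average per-batch accuracy over full batches only (hypothetical helper)."""
    num_batches = len(images) // batch_size  # a trailing partial batch is skipped
    per_batch = []
    for i in range(num_batches):
        start = i * batch_size
        batch_x = images[start:start + batch_size]
        batch_y = labels[start:start + batch_size]
        # Fraction of correct predictions in this batch
        per_batch.append(np.mean(predict_fn(batch_x) == batch_y))
    return float(np.mean(per_batch)), num_batches * batch_size


def parity_model(batch):
    """Stand-in for a trained network: predicts the parity of the first feature."""
    return batch[:, 0] % 2


# Toy data: 300 examples whose label is the parity of their first feature.
rng = np.random.default_rng(0)
toy_images = rng.integers(0, 10, size=(300, 4))
toy_labels = toy_images[:, 0] % 2

accuracy, evaluated = batched_accuracy(parity_model, toy_images, toy_labels)
print('accuracy %.4f over %d of %d examples' % (accuracy, evaluated, len(toy_images)))
```

With 300 examples and a batch size of 128, only two batches (256 examples) are scored. Because every scored batch has the same size, the mean of the per-batch accuracies equals the overall accuracy on those examples.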