def test(net, image_name, model_dir='./model/'):
    """Run inference on one image using the latest checkpoint in *model_dir*.

    Args:
        net: Mapping of graph endpoints. This function reads the
            ``'image'`` and ``'drop_rate'`` entries (used as feed keys) and
            the ``'score'`` tensor.
        image_name: Passed to ``build_image``; its result is fed to
            ``net['image']``.
        model_dir: Directory searched by ``tf.train.latest_checkpoint``.
            Defaults to ``'./model/'``, the previously hard-coded path.

    Returns:
        The value of ``tf.argmax(net['score'], dimension=3)`` as returned by
        ``sess.run``, i.e. the index of the maximum score along axis 3.

    Raises:
        RuntimeError: If *model_dir* contains no checkpoint. ``RuntimeError``
            subclasses ``Exception``, so existing ``except Exception``
            handlers still catch it.
    """
    image = build_image(image_name)

    # Locate the checkpoint before building a Saver or opening a Session, so a
    # missing model fails fast without allocating session resources.
    model_file = tf.train.latest_checkpoint(model_dir)
    if not model_file:
        raise RuntimeError('Testing needs pre-trained model!')

    # NOTE(review): tf.all_variables() is deprecated in favour of
    # tf.global_variables() in later TensorFlow releases; kept as-is for
    # compatibility with the TF version this file targets.
    saver = tf.train.Saver(tf.all_variables())

    # Build the prediction op once, before the session runs it.
    # assumes net['score'] has rank >= 4, e.g. (batch, h, w, classes), so the
    # argmax yields a per-position class index — TODO confirm.
    # ``dimension`` is the deprecated alias of ``axis`` in later TF releases.
    prediction = tf.argmax(net['score'], dimension=3)

    feed_dict = {
        net['image']: image,
        # NOTE(review): feeding 1 presumably disables dropout at test time
        # (i.e. the graph uses this value as a keep-probability) — verify
        # against how net['drop_rate'] is consumed in the graph.
        net['drop_rate']: 1,
    }

    with tf.Session() as sess:
        saver.restore(sess, model_file)
        result = sess.run(prediction, feed_dict=feed_dict)
    return result
评论列表
文章目录