pretrained.py 文件源码

python
阅读 18 收藏 0 点赞 0 评论 0

项目:SSD_tensorflow_VOC 作者: LevinJ 项目源码 文件源码
def use_fined_model(self):
        """Run the fine-tuned Inception v4 checkpoint on one batch and plot it.

        Builds a fresh graph, loads one batch from the 'train' split of the
        flowers dataset, restores the most recent checkpoint found in
        ``train_dir`` and shows up to ``batch_size`` images, each titled with
        its ground-truth and predicted class names.

        Raises:
            ValueError: if ``train_dir`` contains no checkpoint to restore.
        """
        image_size = inception.inception_v4.default_image_size
        # Number of images to display; the batch actually fetched is sized by
        # self.load_batch (its batch size is not set here).
        batch_size = 3
        flowers_data_dir = "../../data/flower"
        train_dir = '/tmp/inception_finetuned/'

        with tf.Graph().as_default():
            tf.logging.set_verbosity(tf.logging.INFO)

            dataset = flowers.get_split('train', flowers_data_dir)
            images, images_raw, labels = self.load_batch(dataset, height=image_size, width=image_size)

            # Create the model, use the default arg scope to configure the batch norm parameters.
            # NOTE(review): is_training=True is kept from the original even
            # though this is inference -- confirm whether that is intentional.
            with slim.arg_scope(inception.inception_v4_arg_scope()):
                logits, _ = inception.inception_v4(images, num_classes=dataset.num_classes, is_training=True)

            probabilities = tf.nn.softmax(logits)

            # latest_checkpoint returns None when nothing has been saved yet;
            # fail here with a clear message instead of deep inside slim.
            checkpoint_path = tf.train.latest_checkpoint(train_dir)
            if checkpoint_path is None:
                raise ValueError('No checkpoint found in %s' % train_dir)
            init_fn = slim.assign_from_checkpoint_fn(
              checkpoint_path,
              slim.get_variables_to_restore())

            with tf.Session() as sess:
                with slim.queues.QueueRunners(sess):
                    # Non-deprecated replacement for tf.initialize_local_variables().
                    sess.run(tf.local_variables_initializer())
                    init_fn(sess)
                    np_probabilities, np_images_raw, np_labels = sess.run([probabilities, images_raw, labels])

                    # Slicing clamps to the real batch length, so a batch
                    # smaller than batch_size cannot raise IndexError.
                    shown = zip(np_images_raw[:batch_size],
                                np_labels[:batch_size],
                                np_probabilities[:batch_size])
                    for image, true_label, class_probs in shown:
                        predicted_label = np.argmax(class_probs)
                        predicted_name = dataset.labels_to_names[predicted_label]
                        true_name = dataset.labels_to_names[true_label]

                        plt.figure()
                        plt.imshow(image.astype(np.uint8))
                        plt.title('Ground Truth: [%s], Prediction [%s]' % (true_name, predicted_name))
                        plt.axis('off')
                        plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号