resnet50.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:npfl114 作者: ufal 项目源码 文件源码
def __init__(self, checkpoint, threads):
        """Build the ResNet-50 graph in a private session and restore weights.

        Args:
            checkpoint: Checkpoint path passed to ``self.saver.restore``.
            threads: Thread count used for both the inter-op and intra-op
                parallelism settings of the session.

        Reads ``self.HEIGHT``, ``self.WIDTH`` and ``self.CLASSES``, which are
        expected to be defined on the class (not visible in this block).
        """
        # The session owns a freshly created graph, so nothing built here
        # lands in the process-wide default graph.
        session_config = tf.ConfigProto(
            inter_op_parallelism_threads=threads,
            intra_op_parallelism_threads=threads,
        )
        self.session = tf.Session(graph=tf.Graph(), config=session_config)

        with self.session.graph.as_default():
            # Network input: float32 batch with a free leading dimension.
            input_shape = [None, self.HEIGHT, self.WIDTH, 3]
            self.images = tf.placeholder(tf.float32, input_shape)

            # Build ResNet v1-50 with the arg scope configured for inference.
            resnet_v1 = tf_slim.nets.resnet_v1
            with tf_slim.arg_scope(resnet_v1.resnet_arg_scope(is_training=False)):
                net_output, _ = resnet_v1.resnet_v1_50(
                    self.images, num_classes=self.CLASSES
                )

            # Drop axes 1 and 2 of the network output, then take the argmax
            # along axis 1 of what remains.
            squeezed = tf.squeeze(net_output, [1, 2])
            self.predictions = tf.argmax(squeezed, 1)

            # The saver is created after the network so it covers the
            # variables that exist in the graph at this point.
            self.saver = tf.train.Saver()
            self.saver.restore(self.session, checkpoint)

            # JPEG loading pipeline: file name -> bytes -> 3-channel image
            # -> cropped/padded to HEIGHT x WIDTH.
            self.jpeg_file = tf.placeholder(tf.string, [])
            file_contents = tf.read_file(self.jpeg_file)
            decoded_image = tf.image.decode_jpeg(file_contents, channels=3)
            self.jpeg_data = tf.image.resize_image_with_crop_or_pad(
                decoded_image, self.HEIGHT, self.WIDTH
            )
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号