def __init__(self, checkpoint, threads):
    """Build a ResNet-50 v1 inference graph in a private session and restore weights.

    Args:
        checkpoint: Checkpoint path/prefix passed to ``tf.train.Saver.restore``.
        threads: Used for both ``inter_op_parallelism_threads`` and
            ``intra_op_parallelism_threads`` of the session config.

    Sets ``self.session``, ``self.images`` (float32 placeholder of shape
    ``[None, self.HEIGHT, self.WIDTH, 3]``), ``self.predictions`` (argmax over
    the squeezed logits), ``self.saver``, and the JPEG-loading pair
    ``self.jpeg_file`` (string path placeholder) / ``self.jpeg_data``.

    If building the graph or restoring the checkpoint raises, the session is
    closed before the exception propagates so its resources are released.
    """
    # Create the session with its own fresh graph, so everything below is
    # isolated from the process-wide default graph.
    self.session = tf.Session(
        graph=tf.Graph(),
        config=tf.ConfigProto(inter_op_parallelism_threads=threads,
                              intra_op_parallelism_threads=threads))
    try:
        with self.session.graph.as_default():
            # Construct the model
            # NOTE(review): images are fed to the net as given; confirm whether
            # callers apply the preprocessing the checkpoint was trained with.
            self.images = tf.placeholder(tf.float32, [None, self.HEIGHT, self.WIDTH, 3])
            with tf_slim.arg_scope(tf_slim.nets.resnet_v1.resnet_arg_scope(is_training=False)):
                resnet, _ = tf_slim.nets.resnet_v1.resnet_v1_50(self.images, num_classes=self.CLASSES)
            # Squeezing axes 1 and 2 assumes the logits come back as
            # [batch, 1, 1, CLASSES] -- confirm against the tf_slim version in use.
            self.predictions = tf.argmax(tf.squeeze(resnet, [1, 2]), 1)

            # Load the checkpoint. The Saver is built with no var_list, so it
            # covers the variables existing at this point (those created by
            # resnet_v1_50); the JPEG ops below create no variables.
            self.saver = tf.train.Saver()
            self.saver.restore(self.session, checkpoint)

            # JPG loading: read the file named by jpeg_file, decode it to 3
            # channels, then crop/pad it to HEIGHT x WIDTH. These ops must be
            # built inside this graph so self.session can run them.
            self.jpeg_file = tf.placeholder(tf.string, [])
            self.jpeg_data = tf.image.resize_image_with_crop_or_pad(
                tf.image.decode_jpeg(tf.read_file(self.jpeg_file), channels=3),
                self.HEIGHT, self.WIDTH)
    except BaseException:
        # Construction failed: release the session's resources instead of
        # leaving them to garbage collection, then propagate the original error.
        self.session.close()
        raise
评论列表
文章目录