def build_detector(self):
    """Build the inference graph for detection and/or segmentation.

    Creates two input placeholders on ``self``:

    * ``self.image_ph`` -- float32 image of shape ``[H, W, 3]`` (H and W are
      free). It is resized to ``config['image_size']`` x
      ``config['image_size']`` and fed to the network trunk as a batch of one.
    * ``self.seg_ph`` -- int32 ground-truth segmentation map of shape
      ``[H, W]``. It is only read when ground truth is available
      (``not self.no_gt``).

    Which heads are built is controlled by ``args.detect`` and
    ``args.segment``. ``args`` is resolved outside this method (presumably
    the parsed command-line flags).

    * Detection: builds the multibox head, softmaxes the confidences and
      passes locations, confidences and ``self.bboxer.tiling`` to
      ``self.nms``.
    * Segmentation: sets ``self.segmentation`` to an int32 ``[H, W]``
      per-pixel class map at the input image's original resolution. It also
      sets ``self.mean_iou`` and ``self.iou_update``, which are zero
      constants when there is no ground truth.
    """
    img_size = self.config['image_size']
    self.image_ph = tf.placeholder(shape=[None, None, 3],
                                   dtype=tf.float32, name='img_ph')
    self.seg_ph = tf.placeholder(shape=[None, None], dtype=tf.int32,
                                 name='seg_ph')

    # The trunk gets a fixed square input; expand_dims adds a batch axis of 1.
    img = tf.image.resize_bilinear(tf.expand_dims(self.image_ph, 0),
                                   (img_size, img_size))
    self.net.create_trunk(img)

    if args.detect:
        self.net.create_multibox_head(self.loader.num_classes)
        confidence = tf.nn.softmax(
            tf.squeeze(self.net.outputs['confidence']))
        location = tf.squeeze(self.net.outputs['location'])
        self.nms(location, confidence, self.bboxer.tiling)

    if args.segment:
        self.net.create_segmentation_head(self.loader.num_classes)
        # Upsample the logits back to the original (H, W) of the input image.
        seg_shape = tf.shape(self.image_ph)[:2]
        seg_logits = tf.image.resize_bilinear(
            self.net.outputs['segmentation'], seg_shape)
        # resize_bilinear returns [1, H, W, C]. Drop ONLY the batch axis: a
        # bare tf.squeeze() would also collapse any other size-1 axis (e.g. a
        # single-channel logit map), so argmax(axis=-1) would then reduce
        # over the wrong axis and the reshape below would fail.
        seg_logits = tf.squeeze(seg_logits, axis=0)
        self.segmentation = tf.cast(tf.argmax(seg_logits, axis=-1), tf.int32)
        self.segmentation = tf.reshape(self.segmentation, seg_shape)
        self.segmentation.set_shape([None, None])

        if not self.no_gt:
            # Score only pixels whose ground-truth label passes this mask.
            # NOTE(review): `<=` keeps labels equal to num_classes. If valid
            # labels are 0..num_classes-1 (with e.g. 255 as "ignore"), this
            # should probably be `<`. Confirm against the training loss and
            # the label convention of self.loader.
            easy_mask = self.seg_ph <= self.loader.num_classes
            predictions = tf.boolean_mask(self.segmentation, easy_mask)
            labels = tf.boolean_mask(self.seg_ph, easy_mask)
            # mean_iou returns (value, update op), presumably a streaming
            # metric like tf.metrics.mean_iou -- verify the helper's origin.
            self.mean_iou, self.iou_update = mean_iou(
                predictions, labels, self.loader.num_classes)
        else:
            # No ground truth: expose dummy tensors so callers can still run
            # the same fetches.
            self.mean_iou = tf.constant(0)
            self.iou_update = tf.constant(0)
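# Web-page residue from the scraped source, not part of the program:
# "评论列表" ("comment list") and "文章目录" ("table of contents").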