def multilabel_image_to_class(label_image: tf.Tensor, classes_file: str) -> tf.Tensor:
    """Turn a colour-coded label image into per-pixel boolean label vectors.

    Each pixel is matched to the class colour with the smallest squared
    distance to it, and the label entries of that class are looked up.

    Args:
        label_image: Rank-3 or rank-4 tensor whose last axis holds the colour
            channels; per the shape notes below, [H,W,3] or [B,H,W,3].
        classes_file: Passed unchanged to
            `get_classes_color_from_file_multilabel`, which supplies the class
            colours and the per-class labels.

    Returns:
        Boolean tensor: the entries of `colors_labels` gathered at each
        pixel's nearest class index, tested with `> 0`.
        NOTE(review): presumably `colors_labels` is [C, n_labels], giving
        [H,W,n_labels] or [B,H,W,n_labels] -- confirm against the helper.

    Raises:
        NotImplementedError: If `label_image` has a rank other than 3 or 4.
    """
    classes_color_values, colors_labels = get_classes_color_from_file_multilabel(classes_file)

    with tf.name_scope('LabelAssign'):
        rank = len(label_image.get_shape())
        if rank not in (3, 4):
            raise NotImplementedError('Length is : {}'.format(rank))

        # Build the index tuples from the rank rather than spelling out each case:
        #   rank 3 -> image[:, :, None, :]     colours[None, None, :, :]
        #   rank 4 -> image[:, :, :, None, :]  colours[None, None, None, :, :]
        keep = slice(None)
        leading_axes = rank - 1  # every axis in front of the colour channels
        image_index = (keep,) * leading_axes + (None, keep)
        colour_index = (None,) * leading_axes + (keep, keep)

        pixels = tf.cast(label_image[image_index], tf.float32)  # [H,W,1,3] or [B,H,W,1,3]
        palette = tf.constant(classes_color_values[colour_index])  # [1,1,C,3] or [1,1,1,C,3]
        diff = pixels - palette  # [H,W,C,3] or [B,H,W,C,3]

        squared_distance = tf.reduce_sum(tf.square(diff), axis=-1)  # [H,W,C] or [B,H,W,C]
        nearest_class = tf.argmin(squared_distance, axis=-1)  # [H,W] or [B,H,W]
        return tf.gather(colors_labels, nearest_class) > 0
# (Page-navigation residue from the scraped source: "Comment list", "Table of contents".)