def colors_to_dimensions(image_tensor, colors):
    """Convert a color-coded label image into a stack of per-color masks.

    For each entry of ``colors``, build a mask that is 1.0 where the
    pixel's values along the last axis of ``image_tensor`` all equal that
    color, and 0.0 elsewhere. The masks are stacked along a new last axis
    in the same order as ``colors``.

    Args:
        image_tensor: Tensor whose last axis holds the color channels
            (presumably ``(height, width, channels)`` or with a leading
            batch axis -- TODO confirm with callers).
        colors: Iterable of colors. Each color is compared element-wise
            against ``image_tensor`` with ``tf.equal``, so it is assumed to
            broadcast against the last axis (e.g. a length-``channels``
            sequence).

    Returns:
        A float32 tensor whose last axis has one entry per color. Assuming
        each color broadcasts as above, the shape is
        ``image_tensor.shape[:-1] + [len(colors)]``.
    """
    single_label_tensors = []
    for single_label_color in colors:
        # True only where every channel of the pixel matches this color.
        is_color = tf.reduce_all(
            tf.equal(image_tensor, single_label_color),
            axis=-1,
        )
        # Cast the boolean mask directly to 1.0/0.0. The previous
        # tf.where(..., tf.fill(is_color.shape, 1.0), tf.fill(..., 0.0))
        # needed a fully-known static shape. It failed under tf.function or
        # graph mode whenever a dimension (e.g. batch) was None.
        single_label_tensors.append(tf.cast(is_color, tf.float32))
    return tf.stack(single_label_tensors, axis=-1)
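# Web-page navigation residue from the original paste, not part of the code:
# 评论列表 ("Comment list")
# 文章目录 ("Table of contents")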