def preprocess_labels(label):
    """Binarize a label mask and reshape it for the loss computation.

    Args:
        label: Either a NumPy array holding the mask, or anything that
            ``PIL.Image.open`` accepts (e.g. a file path). When an image is
            opened, only its first band (``split()[0]``) is used, converted
            to ``uint8``.

    Returns:
        A boolean NumPy array that is ``True`` where the label value is
        strictly greater than half of the label's maximum value. For a 2-D
        input of shape ``(rows, cols)`` the result has shape
        ``(1, rows, cols, 1)``: a leading and a trailing singleton axis are
        added.
    """
    # isinstance (rather than an exact type check) also accepts ndarray
    # subclasses such as np.memmap. An exact check would hand those to
    # Image.open and fail.
    if not isinstance(label, np.ndarray):
        # The context manager closes the file handle deterministically.
        # np.array copies the pixel data, so it stays valid after the close.
        with Image.open(label) as image:
            label = np.array(image.split()[0], dtype=np.uint8)
    # Threshold at half of the maximum value. For an all-zero mask the
    # threshold is 0, so every element maps to False.
    threshold = np.max(label) * 0.5
    mask = np.greater(label, threshold)
    # Add a leading and a trailing singleton axis: (rows, cols) -> (1, rows, cols, 1).
    # NOTE(review): the result is boolean. An earlier TF version cast it to
    # float32; use ``.astype(np.float32)`` if the loss needs float targets
    # (confirm against callers).
    return np.expand_dims(np.expand_dims(mask, axis=0), axis=3)
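# (Residue from the web page this code was copied from: "Comment list",
# "Table of contents". Not part of the program.)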