def one_hot_mask(labels, num_classes, scope=None):
    """Compute a per-pixel one-hot encoding of a label image.

    Each pixel's integer label is expanded into a vector of length
    ``num_classes`` holding 1.0 at the label's index and 0.0 elsewhere.

    Args:
      labels: rank-3 tensor of shape (height, width, 1) holding integer
        class labels; values are cast to int32 before use.
      num_classes: number of classes, i.e. the length of each one-hot vector.
      scope: optional name scope; defaults to "OneHotMask".

    Returns:
      Tensor of shape (height, width, num_classes) containing 1.0 at each
      pixel's class index and 0.0 everywhere else.

    Raises:
      ValueError: if the last dimension of ``labels`` is not 1.
    """
    with tf.name_scope(scope, "OneHotMask", [labels]):
        # NOTE(review): assumes `_shape` (defined elsewhere in this module)
        # returns the static dimensions as Python values -- confirm.
        height, width, depth = _shape(labels)
        # Raise rather than `assert` so the check survives `python -O`.
        if depth != 1:
            raise ValueError(
                "one_hot_mask expects labels with a trailing dimension of 1, "
                "got %r" % (depth,))
        # One label per row, shape (height * width, 1). Cast to int32 so it
        # matches the int32 row indices produced by tf.range below.
        sparse_labels = tf.cast(tf.reshape(labels, [-1, 1]), tf.int32)
        sparse_size, _ = _shape(sparse_labels)
        # Row index of every pixel, shape (sparse_size, 1).
        indices = tf.reshape(tf.range(0, sparse_size, 1), [-1, 1])
        # (row, label) coordinate pairs, shape (sparse_size, 2).
        # TF >= 1.0 takes `values` first and `axis` second; the legacy
        # tf.concat(axis, values) order raises there.
        coordinates = tf.concat([indices, sparse_labels], 1)
        # Rows are ascending and unique, which satisfies sparse_to_dense's
        # default index validation; each row gets 1.0 at its label column.
        dense_result = tf.sparse_to_dense(
            coordinates, [sparse_size, num_classes], 1.0, 0.0)
        return tf.reshape(dense_result, [height, width, num_classes])
评论列表
文章目录