def one_hot_mask(labels, num_classes, scope=None):
    """Compute 1-hot encodings for masks.

    Given a single label image, this computes the one-hot encoding at
    each pixel.

    Args:
        labels: (height, width, 1) tensor containing labels. Values are cast
            to int32 and used as class indices, so they are presumably
            expected to lie in [0, num_classes); this function does not check
            the range itself. Assumes `_shape` (defined elsewhere in this
            module) returns static dimensions, so height/width presumably
            need to be known at graph-construction time -- TODO confirm.
        num_classes: number of classes (size of the output's last axis).
        scope: optional scope name.

    Returns:
        Tensor of shape (height, width, num_classes) with 1.0 at each
        pixel's label index and 0.0 everywhere else.

    Raises:
        ValueError: if the last dimension of `labels` is not 1. A `labels`
            tensor that is not rank 3 presumably also fails, when the result
            of `_shape` is unpacked into three values.
    """
    with ops.name_scope(scope, "OneHotMask", [labels]):
        height, width, depth = _shape(labels)
        if depth != 1:
            # Explicit check rather than `assert`, which is stripped under -O.
            raise ValueError(
                "labels must have a last dimension of size 1, got %s" % depth)
        # Flatten to one label per row: shape (height * width, 1).
        flat_labels = math_ops.to_int32(array_ops.reshape(labels, [-1, 1]))
        num_pixels, _ = _shape(flat_labels)
        # Pair each row index with that row's label, giving the (row, class)
        # coordinate of the single 1.0 entry in each row of the dense result.
        rows = array_ops.reshape(math_ops.range(0, num_pixels, 1), [-1, 1])
        coords = array_ops.concat([rows, flat_labels], 1)
        dense = sparse_ops.sparse_to_dense(coords, [num_pixels, num_classes],
                                           1.0, 0.0)
        # Restore the spatial layout with the class axis last.
        return array_ops.reshape(dense, [height, width, num_classes])
# misc.py source file
# python
# Reads: 25
# Favorites: 0
# Likes: 0
# Comments: 0
# Comment list
# Table of contents