misc.py source code

python

Project: DeepLearning_VirtualReality_BigData_Project  Author: rashmitripathi
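from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops


def _shape(tensor):
  """Return the static shape of a tensor as a list (helper used by misc.py)."""
  return tensor.get_shape().as_list()


def one_hot_mask(labels, num_classes, scope=None):
  """Compute 1-hot encodings for masks.

  Given a label image, this computes the one-hot encoding at
  each pixel.

  Args:
    labels: (height, width, 1) tensor containing integer class labels;
      its shape must be fully known statically.
    num_classes: number of classes
    scope: optional scope name

  Returns:
    Tensor of shape (height, width, num_classes) with
    a 1-hot encoding.
  """
  with ops.name_scope(scope, "OneHotMask", [labels]):
    # Static shape: exactly one label channel is expected.
    height, width, depth = _shape(labels)
    assert depth == 1
    # Flatten to one label per row: shape (height * width, 1).
    sparse_labels = math_ops.cast(
        array_ops.reshape(labels, [-1, 1]), dtypes.int32)
    sparse_size, _ = _shape(sparse_labels)
    # Row index of each pixel: shape (height * width, 1).
    indices = array_ops.reshape(math_ops.range(0, sparse_size, 1), [-1, 1])
    # (pixel, label) coordinates of the nonzero entries.
    concated = array_ops.concat([indices, sparse_labels], 1)
    # Scatter 1.0 at each coordinate and fill everything else with 0.0.
    dense_result = sparse_ops.sparse_to_dense(concated,
                                              [sparse_size, num_classes], 1.0,
                                              0.0)
    result = array_ops.reshape(dense_result, [height, width, num_classes])
    return result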
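To make the sparse-to-dense trick easier to follow, here is a minimal NumPy sketch of the same computation, assuming a `(height, width, 1)` integer label image. The name `one_hot_mask_np` is hypothetical and not part of misc.py. It flattens the labels, pairs each flat pixel index with its label, writes 1.0 at those `(pixel, class)` coordinates of a zero matrix, and reshapes back to the image grid, mirroring the `concat` + `sparse_to_dense` + `reshape` steps above.

```python
import numpy as np


def one_hot_mask_np(labels, num_classes):
    """Hypothetical NumPy sketch of one_hot_mask for a (height, width, 1) image."""
    height, width, depth = labels.shape
    assert depth == 1, "expected a single label channel"
    # Flatten to one label per pixel: shape (height * width,).
    sparse_labels = labels.reshape(-1).astype(np.int32)
    sparse_size = sparse_labels.shape[0]
    # Row index for every pixel, analogous to math_ops.range above.
    indices = np.arange(sparse_size)
    dense_result = np.zeros((sparse_size, num_classes), dtype=np.float32)
    # Analogue of sparse_to_dense(..., 1.0, 0.0): set 1.0 at (pixel, label).
    dense_result[indices, sparse_labels] = 1.0
    return dense_result.reshape(height, width, num_classes)


labels = np.array([[[0], [2]],
                   [[1], [2]]])  # a 2x2 label image with 3 classes
mask = one_hot_mask_np(labels, num_classes=3)
print(mask.shape)  # (2, 2, 3)
print(mask[0, 1])  # [0. 0. 1.]
```

One difference worth noting: NumPy fancy indexing silently wraps negative labels around, whereas the TensorFlow op rejects out-of-range indices, so a production version would validate that `0 <= label < num_classes` first.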