metric_ops_test.py source code

python

Project: lsdc  Author: febert
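import numpy as np
import tensorflow as tf


def _binary_3d_label_to_sparse_value(labels):
  """Convert a dense 3D binary indicator tensor to a sparse tensor value.

  Only entries equal to 1 in `labels` are included in the result; every
  other entry must be 0.

  Args:
    labels: Dense 3D binary indicator tensor whose last dimension indexes
      classes.

  Returns:
    `SparseTensorValue` whose values are indices along the last dimension of
    `labels` (the class ids of the 1 entries).
  """
  indices = []
  values = []
  for d0, labels_d0 in enumerate(labels):
    for d1, labels_d1 in enumerate(labels_d0):
      # d2 counts the 1 entries already seen in this (d0, d1) row, so the
      # sparse indices along the last axis are packed as 0, 1, 2, ...
      d2 = 0
      for class_id, label in enumerate(labels_d1):
        if label == 1:
          values.append(class_id)
          indices.append([d0, d1, d2])
          d2 += 1
        else:
          assert label == 0
  # The dense shape keeps the full size of the last (class) dimension.
  shape = [len(labels), len(labels[0]), len(labels[0][0])]
  # tf.SparseTensorValue is the TF 1.x name; under TF 2.x use
  # tf.compat.v1.SparseTensorValue instead.
  return tf.SparseTensorValue(
      np.array(indices, np.int64),
      np.array(values, np.int64),
      np.array(shape, np.int64))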
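The loop above stores class ids as values and uses a per-row counter `d2`, not the class id, as the last sparse index. To make that layout concrete, here is a minimal vectorized NumPy sketch of the same conversion. It is an illustration under our own assumptions, not code from the lsdc project. The name `binary_3d_label_to_sparse_parts` is hypothetical, and it returns the plain `(indices, values, dense_shape)` arrays instead of a `SparseTensorValue` so that it runs without TensorFlow.

```python
import numpy as np


def binary_3d_label_to_sparse_parts(labels):
  """Hypothetical NumPy-only sketch of the dense-to-sparse conversion.

  Returns (indices, values, dense_shape) as int64 arrays, i.e. the triple
  the function above packs into a SparseTensorValue.
  """
  labels = np.asarray(labels)
  # Mirror the original assert: only 0/1 entries are allowed.
  if not np.isin(labels, (0, 1)).all():
    raise ValueError("labels must contain only 0 and 1")
  # (d0, d1, class_id) of every 1 entry, in row-major order.
  d0, d1, class_id = np.nonzero(labels == 1)
  # Running count of 1s along the last axis, minus 1, is the 0-based
  # position d2 of each 1 within its (d0, d1) row.
  d2 = (np.cumsum(labels, axis=-1) - 1)[d0, d1, class_id]
  indices = np.stack([d0, d1, d2], axis=1).astype(np.int64)
  values = class_id.astype(np.int64)
  dense_shape = np.array(labels.shape, np.int64)
  return indices, values, dense_shape


labels = [[[0, 1, 1], [1, 0, 0]],
          [[0, 0, 0], [0, 0, 1]]]
indices, values, dense_shape = binary_3d_label_to_sparse_parts(labels)
# indices     -> [[0, 0, 0], [0, 0, 1], [0, 1, 0], [1, 1, 0]]
# values      -> [1, 2, 0, 2]
# dense_shape -> [2, 2, 3]
```

On this input the loop-based function produces the same indices, values and shape. Row `(0, 0)` has 1s at classes 1 and 2, which land at sparse positions 0 and 1. The dense shape still reports all 3 classes.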