sparse_ops.py 文件源码

python
阅读 16 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def dense_to_sparse_tensor(dense_tensor, ignore_value=None):
  """Builds a `SparseTensor` holding the cells of `dense_tensor` != `ignore_value`.

  Every cell whose value differs from `ignore_value` becomes an entry of the
  result; cells equal to `ignore_value` are left out.

  Args:
    dense_tensor: A `Tensor`, or anything `ops.convert_to_tensor` accepts. Its
      rank must be statically known.
    ignore_value: The value treated as "empty". Entries of `dense_tensor` equal
      to it are absent from the returned `SparseTensor`. If `None`, the empty
      string is used for string tensors; otherwise the default-constructed
      numpy scalar of the tensor's dtype is used (e.g. 0 for `int`).

  Returns:
    A `SparseTensor` whose `dense_shape` is the shape of `dense_tensor` cast
    to int64. Its `indices` are the coordinates of the kept cells, and its
    `values` are the elements found at those coordinates.

  Raises:
    ValueError: If the rank of `dense_tensor` is not statically known.
  """
  with ops.name_scope("DenseToSparseTensor"):
    tensor = ops.convert_to_tensor(dense_tensor)
    rank = tensor.get_shape().ndims
    if rank is None:
      # TODO(b/32318825): Implement dense_to_sparse_tensor for undefined rank.
      raise ValueError("dense_tensor.get_shape() should be defined, got None.")

    if ignore_value is None:
      # TF string tensors correspond to numpy `object`, and default-constructing
      # that type does not give an empty string, so '' is chosen explicitly.
      ignore_value = ("" if tensor.dtype == dtypes.string
                      else tensor.dtype.as_numpy_dtype())

    shape_i64 = math_ops.cast(array_ops.shape(tensor), dtypes.int64)
    keep_mask = math_ops.not_equal(
        tensor, math_ops.cast(ignore_value, tensor.dtype))
    # One row per kept cell, each row being that cell's full coordinate.
    kept_coords = array_ops.where(keep_mask)

    # Rather than indexing the dense tensor by coordinate, every kept coordinate
    # is mapped to its position in the row-major flattening of the tensor, and
    # the values are gathered from that 1-D view.
    flat_values = array_ops.reshape(tensor, [-1])
    last_axis = rank - 1
    linear_positions = kept_coords[:, last_axis]
    if rank > 1:
      leading_coords = kept_coords[:, :last_axis]
      # NOTE(review): `_multiplier_helper` is defined outside this block. It
      # presumably yields, for each leading axis, the product of the sizes of
      # all axes after it (i.e. row-major strides). Confirm against its source.
      strides = array_ops.pack(
          _multiplier_helper(array_ops.unpack(shape_i64)[1:]))
      leading_offsets = math_ops.reduce_sum(
          math_ops.mul(leading_coords, strides), reduction_indices=[1])
      linear_positions = math_ops.add(linear_positions, leading_offsets)
    kept_values = array_ops.gather(flat_values, linear_positions)
    return sparse_tensor.SparseTensor(kept_coords, kept_values, shape_i64)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号