sparse.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:cxflow-tensorflow 作者: Cognexa 项目源码 文件源码
def dense_to_sparse(inputs: tf.Tensor, mask: Optional[tf.Tensor]=None) -> tf.SparseTensor:
    """
    Convert the given ``inputs`` tensor to a ``SparseTensor`` of its non-zero values.

    Optionally, use the given mask tensor for determining the values to be included in the ``SparseTensor``.

    :param inputs: input dense tensor
    :param mask: optional mask tensor
    :return: sparse tensor of non-zero (or masked) values
    """
    idx = tf.where(tf.not_equal((mask if mask is not None else inputs), 0))
    return tf.SparseTensor(idx, tf.gather_nd(inputs, idx), tf.shape(inputs, out_type=tf.int64))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号