varlen_support.py source code

python

Project: RNNVis  Author: myaooo
import tensorflow as tf


def sequence_length(sequence):
    """
    Get the length tensor of a batched sequence by counting its non-padding steps.
        For embeddings, i.e. a 3D input sequence, the padded part should be filled with 0s.
        For word ids, i.e. a 2D input sequence, the padded part should be filled with -1s.
    :param sequence: a Tensor of shape [batch_size, max_length(, embedding_size)]
    :return: a 1D int32 Tensor of shape (batch_size,) holding the length of each sequence
    """
    # a 3D tensor means embedded inputs; a 2D tensor means word ids
    embedding = len(sequence.get_shape()) == 3
    if embedding:
        # all-zero (padding) vectors become 0., all other steps become 1.
        used = tf.sign(tf.reduce_max(tf.abs(sequence), axis=2))
    else:
        # the padding id -1 becomes 0, any real id >= 0 becomes 1
        used = tf.sign(sequence + 1)
    # count the non-padding steps in each row
    length = tf.reduce_sum(used, axis=1)
    length = tf.cast(length, tf.int32)
    return length
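The function relies only on a few elementwise ops, so its masking logic can be checked without building a TensorFlow graph. Below is a minimal NumPy sketch that mirrors the same steps (tf.abs to np.abs, tf.reduce_max to np.max, tf.sign to np.sign, tf.reduce_sum to np.sum). The name sequence_length_np and the sample arrays are hypothetical, for illustration only; this is not part of the RNNVis source.

```python
import numpy as np


def sequence_length_np(sequence):
    """Hypothetical NumPy mirror of sequence_length, for illustration only.

    3D input (batch, time, embedding): a step is padding if its vector is all zeros.
    2D input (batch, time) of word ids: a step is padding if its id is -1.
    """
    sequence = np.asarray(sequence)
    if sequence.ndim == 3:
        # max |x| over the embedding axis is 0 only for all-zero (padding) vectors
        used = np.sign(np.max(np.abs(sequence), axis=2))
    else:
        # id + 1 is 0 for the -1 padding id and positive for any real id >= 0
        used = np.sign(sequence + 1)
    # count the non-padding steps in each row
    return np.sum(used, axis=1).astype(np.int32)


# Word ids, padded with -1 (note that id 0 is a real token and is counted)
ids = np.array([[4, 7, 2, -1, -1],
                [0, 3, -1, -1, -1]])
print(sequence_length_np(ids))  # [3 2]

# Embeddings, padded with zero vectors
emb = np.zeros((2, 4, 3), dtype=np.float32)
emb[0, :3] = 0.5   # 3 real steps
emb[1, :1] = -1.0  # 1 real step; negative values still count because of abs
print(sequence_length_np(emb))  # [3 1]
```

Because the length is a sum over the mask rather than the index of the first padding step, it matches the true length only when padding is trailing. In the embedding case, a real token whose vector happens to be all zeros would also be counted as padding.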