import tensorflow as tf


def sequence_length(sequence):
    """
    Get the length tensor of a batched sequence.
    When the input is an embedded (3D) tensor, padding positions should be filled with 0.s;
    when the input is a word-id (2D) tensor, padding positions should be filled with -1s.
    :param sequence: a Tensor of shape [batch_size, max_length(, embedding_size)]
    :return: a 1D Tensor of shape (batch_size,) holding the true length of each sequence
    """
    embedding = len(sequence.get_shape()) == 3
    if embedding:
        # all-zero (padding) vectors become 0., real time steps become 1.
        used = tf.sign(tf.reduce_max(tf.abs(sequence), axis=2))
    else:
        # -1 (padding) ids become 0, real ids become 1.
        used = tf.sign(sequence + 1)
    length = tf.reduce_sum(used, axis=1)
    length = tf.cast(length, tf.int32)
    return length
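
A minimal usage sketch, assuming TensorFlow 2.x eager execution (under TF 1.x the returned tensor would need to be evaluated in a Session); the example values are illustrative only:

# 2D word-id input: padding positions are filled with -1.
word_ids = tf.constant([[4, 7, 2, -1, -1],
                        [3, 1, -1, -1, -1]])
print(sequence_length(word_ids))   # -> [3 2]

# 3D embedded input: padding positions are all-zero vectors.
embedded = tf.constant([[[0.5, 0.1], [0.2, 0.3], [0.0, 0.0]],
                        [[0.7, 0.9], [0.0, 0.0], [0.0, 0.0]]])
print(sequence_length(embedded))   # -> [2 1]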