def retrieve_seq_length_op2(data):
    """Count the strictly positive entries in each row of a zero-padded batch.

    When the padding (on the right hand side) is all zeros and every real
    entry is greater than zero, this count is the length of each sequence.
    Entries that are zero or negative are not counted.

    Parameters
    -----------
    data : tensor
        [batch_size, n_step(max)] with zero padding on right hand side.

    Returns
    -------
    tensor
        int32 tensor giving, for each row, the number of entries greater
        than zero (the sum is taken over axis 1).

    Examples
    --------
    >>> data = [[1,2,0,0,0],
    ...         [1,2,3,0,0],
    ...         [1,2,6,1,0]]
    >>> o = retrieve_seq_length_op2(data)
    >>> sess = tf.InteractiveSession()
    >>> print(o.eval())
    [2 3 4]
    """
    # Zeros with the same shape/dtype as the input act as the padding reference.
    padding_reference = tf.zeros_like(data)
    # Boolean mask: True wherever the entry is greater than the padding value.
    is_positive = tf.greater(data, padding_reference)
    # Cast the mask to int32 and add it up along axis 1 (the step axis).
    positive_as_int = tf.cast(is_positive, tf.int32)
    return tf.reduce_sum(positive_as_int, 1)
# Web-page navigation text left over from the source page (not code):
# "Comment list" / "Article contents".