def remove_padding(self, input_text):
    """Trim the padding columns that every row of a batch shares.

    Entries >= 0 are treated as words and negative entries as padding.
    Each row's length is the number of word entries it contains. The batch
    is then truncated along axis 1 to the longest such length.

    NOTE(review): the count-based length equals the real sequence length
    only when padding is trailing (words first, padding after). Confirm
    that callers guarantee this.

    Args:
        input_text: integer id tensor. Presumably shaped (batch, max_time),
            since lengths are reduced over axis 1 (TODO confirm).

    Returns:
        A tuple ``(input_text, sequence_length)``. The first item is the
        input sliced to ``input_text[:, :max(sequence_length)]``. The second
        item holds the per-row int32 word counts.
    """
    # True for words (ids >= 0), False for padding (negative ids).
    mask = tf.greater_equal(input_text, 0)
    sequence_length = tf.reduce_sum(tf.cast(mask, tf.int32), 1)
    # reduce_max over an empty batch returns the int32 minimum. Clamp to 0 so
    # an empty batch slices to zero columns instead of yielding a negative /
    # overflowing size.
    max_sequence_length = tf.maximum(tf.reduce_max(sequence_length), 0)
    # Keep only the leading columns. This matches the first output of a
    # two-way tf.split without building the discarded remainder.
    input_text = input_text[:, :max_sequence_length]
    return input_text, sequence_length
text_classification_model_simple.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录