def flatten_emb_by_sentence(self, emb, text_len_mask):
    """Flatten per-sentence embeddings and keep only the masked positions.

    Args:
        emb: A rank-2 tensor of shape [num_sentences, max_sentence_length]
            or a rank-3 tensor of shape
            [num_sentences, max_sentence_length, emb_dim].
        text_len_mask: Boolean mask selecting the valid (first N) positions
            of each sentence row; applied after flattening the first two
            dimensions.

    Returns:
        A tensor with the sentence/position dimensions collapsed into one
        and masked down to the selected entries.

    Raises:
        ValueError: If `emb` is neither rank 2 nor rank 3.
    """
    # Dynamic number of rows after collapsing [num_sentences, max_sentence_length].
    flat_rows = tf.shape(emb)[0] * tf.shape(emb)[1]
    rank = len(emb.get_shape())
    # Guard clause: only rank-2 / rank-3 inputs are supported.
    if rank not in (2, 3):
        raise ValueError("Unsupported rank: {}".format(rank))
    if rank == 2:
        flattened = tf.reshape(emb, [flat_rows])
    else:
        # utils.shape preserves the (possibly static) trailing embedding dim.
        flattened = tf.reshape(emb, [flat_rows, utils.shape(emb, 2)])
    return tf.boolean_mask(flattened, text_len_mask)