models.py source code

Language: Python

Project: deeppavlov
Author: deepmipt
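```python
import tensorflow as tf


# Method of the model class in deeppavlov's models.py (shown without the class).
def flatten_emb_by_sentence(self, emb, text_len_mask):
    """
    Merge the sentence and position axes of emb, then drop padding positions.

    Args:
        emb: Embeddings tensor of rank 2 ([num_sentences, max_sentence_length])
            or rank 3 ([num_sentences, max_sentence_length, emb_size]).
        text_len_mask: Boolean tensor of shape [num_sentences, max_sentence_length]
            that is True for the first N (real, non-padding) positions of each row.

    Returns:
        The masked, flattened emb: shape [num_words] or [num_words, emb_size].
    """
    num_sentences = tf.shape(emb)[0]
    max_sentence_length = tf.shape(emb)[1]

    emb_rank = len(emb.get_shape())
    if emb_rank == 2:
        flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length])
    elif emb_rank == 3:
        # Static feature size when known, else the dynamic one
        # (used here in place of the module's utils.shape(emb, 2) helper).
        emb_size = emb.get_shape().as_list()[2] or tf.shape(emb)[2]
        flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length, emb_size])
    else:
        raise ValueError("Unsupported rank: {}".format(emb_rank))
    # Flatten the 2-D mask the same way so it lines up with flattened_emb.
    return tf.boolean_mask(flattened_emb, tf.reshape(text_len_mask, [-1]))
```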
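To make the shape bookkeeping concrete, here is a minimal NumPy analogue of the method above. It is an illustration under stated assumptions, not the deeppavlov/TensorFlow code: `sequence_mask` and `flatten_emb_by_sentence_np` are hypothetical helpers written for this sketch, `sequence_mask` mimics what `tf.sequence_mask` produces, boolean indexing stands in for `tf.boolean_mask`, and the mask is assumed to have shape [num_sentences, max_sentence_length], as the docstring's "first N positions of each row" suggests.

```python
import numpy as np


def sequence_mask(lengths, maxlen):
    """Hypothetical stand-in for tf.sequence_mask: True for the first lengths[i] positions of row i."""
    return np.arange(maxlen)[None, :] < np.asarray(lengths)[:, None]


def flatten_emb_by_sentence_np(emb, text_len_mask):
    """Hypothetical NumPy analogue of flatten_emb_by_sentence (illustration only)."""
    num_sentences, max_sentence_length = emb.shape[:2]
    if emb.ndim == 2:
        flattened_emb = emb.reshape(num_sentences * max_sentence_length)
    elif emb.ndim == 3:
        flattened_emb = emb.reshape(num_sentences * max_sentence_length, emb.shape[2])
    else:
        raise ValueError("Unsupported rank: {}".format(emb.ndim))
    # Boolean indexing on the leading axis plays the role of tf.boolean_mask.
    return flattened_emb[text_len_mask.reshape(-1)]


# Two sentences padded to length 3; their real lengths are 2 and 3.
text_len_mask = sequence_mask([2, 3], maxlen=3)
word_ids = np.array([[1, 2, 0],
                     [3, 4, 5]])
print(flatten_emb_by_sentence_np(word_ids, text_len_mask))  # [1 2 3 4 5]

emb = np.arange(2 * 3 * 4).reshape(2, 3, 4)  # [num_sentences, max_len, emb_size]
print(flatten_emb_by_sentence_np(emb, text_len_mask).shape)  # (5, 4)
```

Padding positions drop out, so the output has one entry (or one row of size emb_size) per real token, five here, and the sentence axis is merged away.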