base_layer.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def forward(self, repres, max_att):
        """
        Args:
            repres - [bsz, a_len|q_len, cont_dim]
            max_att - [bsz, q_len|a_len, cont_dim]
        Return:
            size - [bsz, sentence_len, mp_dim]
        """
        bsz = repres.size(0)
        sent_len = repres.size(1)

        repres = repres.view(-1, self.cont_dim)
        max_att = max_att.view(-1, self.cont_dim)
        repres = multi_perspective_expand_for_2D(repres, self.weight)
        max_att = multi_perspective_expand_for_2D(max_att, self.weight)
        temp = cosine_similarity(repres, max_att, repres.dim()-1)

        return temp.view(bsz, sent_len, self.mp_dim)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号