base_layer.py source code

python

Project: torch_light  Author: ne7ermore
def forward(self, cont_repres, other_cont_repres):
    """
    Args:
        cont_repres - [batch_size, this_len, context_lstm_dim]
        other_cont_repres - [batch_size, other_len, context_lstm_dim]
    Return:
        tensor of size [bsz, this_len, mp_dim*2]: for every position of
        cont_repres, the max and the mean (taken over all positions of
        other_cont_repres) of the multi-perspective cosine similarity
    """
    bsz = cont_repres.size(0)
    this_len = cont_repres.size(1)
    other_len = other_cont_repres.size(1)

    # Flatten batch and time: [bsz * len, cont_dim]
    cont_repres = cont_repres.view(-1, self.cont_dim)
    other_cont_repres = other_cont_repres.view(-1, self.cont_dim)

    # Weight every vector with each of the mp_dim perspectives in self.weight
    # (multi_perspective_expand_for_2D is a module-level helper not shown in this excerpt)
    cont_repres = multi_perspective_expand_for_2D(cont_repres, self.weight)
    other_cont_repres = multi_perspective_expand_for_2D(other_cont_repres, self.weight)

    cont_repres = cont_repres.view(bsz, this_len, self.mp_dim, self.cont_dim)
    other_cont_repres = other_cont_repres.view(bsz, other_len, self.mp_dim, self.cont_dim)

    # [bsz, this_len, 1, self.mp_dim, self.cont_dim]
    cont_repres = cont_repres.unsqueeze(2)
    # [bsz, 1, other_len, self.mp_dim, self.cont_dim]
    other_cont_repres = other_cont_repres.unsqueeze(1)

    # Cosine over the last (cont_dim) axis, broadcast across all position pairs:
    # [bsz, this_len, other_len, self.mp_dim]
    simi = cosine_similarity(cont_repres, other_cont_repres, cont_repres.dim()-1)

    # Pool over other_len; t_max and t_mean are each [bsz, this_len, self.mp_dim]
    t_max, _ = simi.max(2)
    t_mean = simi.mean(2)
    return torch.cat((t_max, t_mean), 2)
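The excerpt relies on two helpers and on attributes (`self.cont_dim`, `self.mp_dim`, `self.weight`) that are not shown, so it cannot run on its own. Below is a minimal self-contained sketch of the same computation, assuming `multi_perspective_expand_for_2D(x, w)` multiplies each row of `x` ([N, cont_dim]) elementwise by each row of `w` ([mp_dim, cont_dim]) to give [N, mp_dim, cont_dim], and that `cosine_similarity(x1, x2, dim)` broadcasts its inputs and reduces over `dim`. Both helper bodies and the wrapper class `PairwiseMaxMeanMatch` are hypothetical stand-ins, not torch_light's code; only `forward` follows the excerpt.

```python
import torch
import torch.nn as nn


def multi_perspective_expand_for_2D(in_tensor, weight):
    """Hypothetical stand-in: scale every row by every perspective vector.

    in_tensor: [N, cont_dim], weight: [mp_dim, cont_dim] -> [N, mp_dim, cont_dim]
    """
    return in_tensor.unsqueeze(1) * weight.unsqueeze(0)


def cosine_similarity(x1, x2, dim, eps=1e-8):
    """Hypothetical stand-in: broadcasting cosine similarity reduced over `dim`."""
    dot = (x1 * x2).sum(dim)
    norms = x1.norm(p=2, dim=dim) * x2.norm(p=2, dim=dim)
    return dot / norms.clamp(min=eps)


class PairwiseMaxMeanMatch(nn.Module):
    """Hypothetical wrapper; only forward() mirrors the excerpt."""

    def __init__(self, cont_dim, mp_dim):
        super().__init__()
        self.cont_dim = cont_dim
        self.mp_dim = mp_dim
        # One learnable weight vector per perspective: [mp_dim, cont_dim]
        self.weight = nn.Parameter(torch.rand(mp_dim, cont_dim))

    def forward(self, cont_repres, other_cont_repres):
        bsz, this_len, _ = cont_repres.shape
        other_len = other_cont_repres.size(1)

        # Apply every perspective to every time step, then restore batch/time axes
        p = multi_perspective_expand_for_2D(cont_repres.view(-1, self.cont_dim), self.weight)
        q = multi_perspective_expand_for_2D(other_cont_repres.view(-1, self.cont_dim), self.weight)
        p = p.view(bsz, this_len, 1, self.mp_dim, self.cont_dim)
        q = q.view(bsz, 1, other_len, self.mp_dim, self.cont_dim)

        # Broadcast to all (i, j) pairs: [bsz, this_len, other_len, mp_dim]
        simi = cosine_similarity(p, q, dim=-1)

        # Max- and mean-pool over the other sequence's positions
        t_max, _ = simi.max(dim=2)
        t_mean = simi.mean(dim=2)
        return torch.cat((t_max, t_mean), dim=2)


torch.manual_seed(0)
layer = PairwiseMaxMeanMatch(cont_dim=6, mp_dim=4)
a = torch.randn(2, 5, 6)  # [bsz, this_len, cont_dim]
b = torch.randn(2, 7, 6)  # [bsz, other_len, cont_dim]
out = layer(a, b)
print(out.shape)  # torch.Size([2, 5, 8])
```

The first mp_dim output channels hold the max and the last mp_dim hold the mean, matching the `torch.cat((t_max, t_mean), 2)` order in the excerpt. Because every position pair is materialized before pooling, memory grows with this_len × other_len × mp_dim × cont_dim. Recent PyTorch releases document `torch.nn.functional.cosine_similarity` as accepting broadcastable inputs, so it could likely replace the hypothetical helper, though its epsilon handling may differ.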