base_layer.py source code

python

Project: torch_light, author: ne7ermore

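import torch

# forward() method excerpted from a layer class in base_layer.py.
# `cosine_similarity` and `multi_perspective_expand_for_2D` come from the
# enclosing module (not shown); the call below matches the signature of
# torch.nn.functional.cosine_similarity(x1, x2, dim).
# self.weight has shape [mp_dim, context_lstm_dim]: one weight vector per perspective.
def forward(self, cont_repres, other_cont_first):
    """
    Args:
        cont_repres - [batch_size, this_len, context_lstm_dim]
        other_cont_first - [batch_size, context_lstm_dim]
    Returns:
        [batch_size, this_len, mp_dim]: cosine similarity of every position
        of cont_repres with other_cont_first, under every perspective
    """
    def expand(context, weight):
        """
        Args:
            context - [batch_size, this_len, context_lstm_dim]
            weight - [mp_dim, context_lstm_dim]
        Returns:
            [batch_size, this_len, mp_dim, context_lstm_dim]
        """
        # [1, 1, mp_dim, context_lstm_dim]
        weight = weight.unsqueeze(0).unsqueeze(0)
        # [batch_size, this_len, 1, context_lstm_dim]
        context = context.unsqueeze(2)
        # Broadcasting scales every position by every perspective's weight vector
        return torch.mul(context, weight)

    # [batch_size, this_len, mp_dim, context_lstm_dim]
    cont_repres = expand(cont_repres, self.weight)

    # [batch_size, mp_dim, context_lstm_dim]
    other_cont_first = multi_perspective_expand_for_2D(other_cont_first, self.weight)
    # [batch_size, 1, mp_dim, context_lstm_dim]; the size-1 axis broadcasts over this_len
    other_cont_first = other_cont_first.unsqueeze(1)
    # Cosine over the last (context_lstm_dim) axis -> [batch_size, this_len, mp_dim]
    return cosine_similarity(cont_repres, other_cont_first, cont_repres.dim() - 1)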
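For each position i and perspective k, forward() computes m[i][k] = cos(W[k] ∘ h[i], W[k] ∘ g), where h is cont_repres, g is other_cont_first, W is self.weight and ∘ is elementwise multiplication. The pure-Python sketch below spells out that arithmetic for a single sequence (no batch axis) so the loops hidden by broadcasting are visible. It is only an illustration: the names `cosine` and `multi_perspective_match` are hypothetical and not part of torch_light, and the eps guard against zero norms is an assumption, not a claim about PyTorch's exact epsilon handling.

```python
import math
from typing import List

Vector = List[float]


def cosine(a: Vector, b: Vector, eps: float = 1e-8) -> float:
    """Cosine similarity of two equal-length vectors; eps avoids division by zero."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / max(norm_a * norm_b, eps)


def multi_perspective_match(
    context: List[Vector], other: Vector, weight: List[Vector]
) -> List[List[float]]:
    """Hypothetical helper: cos(weight[k] * context[i], weight[k] * other) for all i, k.

    context: [this_len][dim]
    other:   [dim]
    weight:  [mp_dim][dim]
    returns: [this_len][mp_dim]
    """
    result = []
    for h in context:                      # loop over positions (this_len)
        row = []
        for w in weight:                   # loop over perspectives (mp_dim)
            weighted_h = [wi * hi for wi, hi in zip(w, h)]
            weighted_o = [wi * oi for wi, oi in zip(w, other)]
            row.append(cosine(weighted_h, weighted_o))
        result.append(row)
    return result


if __name__ == "__main__":
    context = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # this_len = 3, dim = 2
    other = [1.0, 0.0]
    weight = [[1.0, 1.0], [1.0, 0.0]]               # mp_dim = 2
    print(multi_perspective_match(context, other, weight))
```

In this toy input the third position [1.0, 1.0] scores about 0.707 under the first perspective but 1.0 under the second, because the second weight vector zeroes out the dimension where the two vectors differ. That is the point of using several perspectives: each learned weight vector emphasizes different dimensions before the cosine is taken. The PyTorch method gets the same result for the whole batch in one call by inserting size-1 axes with unsqueeze and letting broadcasting replace both loops.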