model.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:torch_light 作者: ne7ermore 项目源码 文件源码
def _aggre(self, q_aware_reps, a_aware_reps):
        """
        Aggregation Layer handle
        Args:
            q_aware_reps - [batch_size, question_len, 11*mp_dim+6]
            a_aware_reps - [batch_size, answer_len, 11*mp_dim+6]
        Return:
            size - [batch_size, aggregation_lstm_dim*4]
        """
        _aggres = []
        _, (q_hidden, _) = self.aggre_lstm(q_aware_reps)
        _, (a_hidden, _) = self.aggre_lstm(a_aware_reps)

        # [batch_size, aggregation_lstm_dim]
        _aggres.append(q_hidden[-2])
        _aggres.append(q_hidden[-1])
        _aggres.append(a_hidden[-2])
        _aggres.append(a_hidden[-1])
        return torch.cat(_aggres, dim=1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号