model.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目：torch_light　作者：ne7ermore　项目源码　文件源码
def forward(self, q_corpora, q_words, a_corpora, a_words):
        """Run the full matching pipeline and return class probabilities.

        Args:
            q_corpora: question id tensor; compared element-wise against
                ``PF_POS`` to build ``self.q_mask`` and fed to
                ``self._word_repre_layer`` together with ``q_words``.
            q_words: second question input passed alongside ``q_corpora`` to
                ``self._word_repre_layer`` (its exact meaning is not visible
                here -- TODO confirm).
            a_corpora: answer id tensor, same role as ``q_corpora``.
            a_words: answer counterpart of ``q_words``.

        Returns:
            ``F.softmax`` of the ``self.l2`` output, normalized over the last
            dimension.

        Side effects:
            Sets ``self.q_mask``, ``self.a_mask``, ``self.q_repres`` and
            ``self.a_repres``. The helper methods called below may read these
            attributes, so they are kept on ``self``.
        """
        # Step 1 - Masks: positions whose id is >= PF_POS are marked True
        # (presumably PF_POS is the first non-special id -- verify).
        self.q_mask = q_corpora.ge(PF_POS)
        self.a_mask = a_corpora.ge(PF_POS)

        # Step 2 - Word Representation Layer
        self.q_repres = self._word_repre_layer((q_corpora, q_words))
        self.a_repres = self._word_repre_layer((a_corpora, a_words))

        # Step 3 - Pairwise cosine similarity, then masking.
        # Assumes q_repres is [bsz, q_len, dim] and a_repres is
        # [bsz, a_len, dim] (TODO confirm). The two unsqueezes broadcast the
        # comparison over every (answer position, question position) pair.
        q_expanded = self.q_repres.unsqueeze(1)  # [bsz, 1, q_len, dim]
        a_expanded = self.a_repres.unsqueeze(2)  # [bsz, a_len, 1, dim]
        simi = F.cosine_similarity(q_expanded, a_expanded, dim=3)  # [bsz, a_len, q_len]
        simi_mask = self._cosine_similarity_mask(simi)

        # Step 4 - Matching Layer.
        # F.dropout defaults to training=True. Tie it to the module's mode so
        # that dropout is disabled after eval() (assumes self is an
        # nn.Module, which provides `training`).
        q_aware_reps, a_aware_reps = self._bilateral_match(simi_mask)
        q_aware_reps = F.dropout(q_aware_reps, p=self.dropout, training=self.training)
        a_aware_reps = F.dropout(a_aware_reps, p=self.dropout, training=self.training)

        # Step 5 - Aggregation Layer
        aggre = self._aggre(q_aware_reps, a_aware_reps)

        # Step 6 - Prediction Layer.
        # Tensor.tanh() replaces the deprecated F.tanh; the result is identical.
        predict = self.l1(aggre).tanh()
        predict = F.dropout(predict, p=self.dropout, training=self.training)
        # Use an explicit dim: implicit-dim softmax is deprecated. For a 2-D
        # [bsz, n_classes] logit tensor, dim=-1 equals the old implicit dim=1.
        return F.softmax(self.l2(predict), dim=-1)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号