model.py source code


Project: MP-CNN-Variants  Author: tuzhucheng
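```python
import torch  # forward() is a method of the model class in model.py and relies on torch.cat


def forward(self, sent1_idx, sent2_idx, ext_feats=None):
    # Select embedding: (batch, seq_len) -> (batch, seq_len, emb_dim), then transpose
    # to the channels-first layout (batch, emb_dim, seq_len) expected by 1-D convolution layers
    sent1 = self.embedding(sent1_idx).transpose(1, 2)
    sent2 = self.embedding(sent2_idx).transpose(1, 2)

    # Sentence modeling module: each sentence yields two groups of pooled outputs (block A and block B)
    sent1_block_a, sent1_block_b = self._get_blocks_for_sentence(sent1)
    sent2_block_a, sent2_block_b = self._get_blocks_for_sentence(sent2)

    # Similarity measurement layer: horizontal comparison on block A,
    # vertical comparison on blocks A and B
    feat_h = self._algo_1_horiz_comp(sent1_block_a, sent2_block_a)
    feat_v = self._algo_2_vert_comp(sent1_block_a, sent2_block_a, sent1_block_b, sent2_block_b)
    # External features are appended only if the model was configured to use them
    combined_feats = [feat_h, feat_v, ext_feats] if self.ext_feats else [feat_h, feat_v]
    feat_all = torch.cat(combined_feats, dim=1)

    # Map the concatenated similarity features to predictions
    preds = self.final_layers(feat_all)
    return preds
```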
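To see how the stages of `forward()` fit together end to end, here is a minimal, self-contained sketch under heavy simplifying assumptions. `TinyMPCNN`, `_compare`, and every layer size are hypothetical: block A is a single convolution over all embedding dimensions pooled with max/min/mean, block B is a grouped convolution that filters each embedding dimension separately and is pooled with max/min, and the comparison uses only cosine similarity and L2 distance. It mirrors the structure of the method above (embed and transpose, build blocks, compare, optionally append `ext_feats`, classify) but is not the MP-CNN-Variants implementation of `_get_blocks_for_sentence`, `_algo_1_horiz_comp`, or `_algo_2_vert_comp`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyMPCNN(nn.Module):
    """Hypothetical, heavily simplified model with the same forward() structure as above.

    Block A: one "holistic" Conv1d over all embedding dimensions, pooled by max/min/mean.
    Block B: a grouped Conv1d that filters each embedding dimension separately, pooled by max/min.
    """

    def __init__(self, vocab_size=100, emb_dim=8, n_filters=4, width=2,
                 ext_feats=False, n_ext=4, n_classes=2):
        super().__init__()
        self.ext_feats = ext_feats
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        self.conv_a = nn.Conv1d(emb_dim, n_filters, width)
        self.conv_b = nn.Conv1d(emb_dim, emb_dim, width, groups=emb_dim)
        # 3 pools x 2 comparisons (block A) + 2 pools x 2 comparisons (block B) + optional extras
        n_feats = 3 * 2 + 2 * 2 + (n_ext if ext_feats else 0)
        self.final_layers = nn.Sequential(
            nn.Linear(n_feats, 16), nn.Tanh(),
            nn.Linear(16, n_classes), nn.LogSoftmax(dim=1))

    def _get_blocks_for_sentence(self, sent):
        # sent: (batch, emb_dim, seq_len); conv outputs: (batch, channels, seq_len - width + 1)
        a = torch.tanh(self.conv_a(sent))
        b = torch.tanh(self.conv_b(sent))
        # Pool over the time axis so each entry is (batch, channels), independent of seq_len
        block_a = {"max": a.max(dim=2).values, "min": a.min(dim=2).values, "mean": a.mean(dim=2)}
        block_b = {"max": b.max(dim=2).values, "min": b.min(dim=2).values}
        return block_a, block_b

    @staticmethod
    def _compare(x1, x2):
        # Cosine similarity and L2 distance per example, stacked to (batch, 2)
        return torch.stack([F.cosine_similarity(x1, x2, dim=1),
                            F.pairwise_distance(x1, x2)], dim=1)

    def forward(self, sent1_idx, sent2_idx, ext_feats=None):
        sent1 = self.embedding(sent1_idx).transpose(1, 2)
        sent2 = self.embedding(sent2_idx).transpose(1, 2)
        s1_a, s1_b = self._get_blocks_for_sentence(sent1)
        s2_a, s2_b = self._get_blocks_for_sentence(sent2)
        feat_a = torch.cat([self._compare(s1_a[p], s2_a[p]) for p in ("max", "min", "mean")], dim=1)
        feat_b = torch.cat([self._compare(s1_b[p], s2_b[p]) for p in ("max", "min")], dim=1)
        combined = [feat_a, feat_b, ext_feats] if self.ext_feats else [feat_a, feat_b]
        return self.final_layers(torch.cat(combined, dim=1))


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyMPCNN(ext_feats=True)
    s1 = torch.randint(0, 100, (3, 7))  # batch of 3 token-id sequences, length 7
    s2 = torch.randint(0, 100, (3, 5))  # second sentences may have a different length
    print(model(s1, s2, torch.randn(3, 4)).shape)  # torch.Size([3, 2])
```

Because each block is pooled over the time axis, the two sentences can have different lengths. The width of the concatenated feature vector depends only on the number of pooling types, the number of comparison measures, and whether `ext_feats` is enabled, so `final_layers` has to be sized to match that count.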