def forward(self, sent1_idx, sent2_idx, ext_feats=None):
    """Score a sentence pair and return the output of ``self.final_layers``.

    Pipeline: embed both sentences, build two groups of block features per
    sentence, compare them horizontally and vertically, optionally append
    external features, and feed the concatenation to the final layers.

    Args:
        sent1_idx: Input passed to ``self.embedding`` for the first sentence
            (presumably token indices of shape ``(batch, seq_len)`` -- TODO
            confirm against the caller).
        sent2_idx: Same as ``sent1_idx``, for the second sentence.
        ext_feats: Extra features concatenated after the similarity features
            along dim 1. Required when ``self.ext_feats`` is truthy; ignored
            otherwise.

    Returns:
        Whatever ``self.final_layers`` returns for the concatenated features.

    Raises:
        TypeError: If ``self.ext_feats`` is truthy but ``ext_feats`` is None.
    """
    # Embed, then swap dims 1 and 2. Presumably (batch, seq, emb) ->
    # (batch, emb, seq) so downstream convolutions treat the embedding as
    # channels -- NOTE(review): confirm shapes.
    sent1 = self.embedding(sent1_idx).transpose(1, 2)
    sent2 = self.embedding(sent2_idx).transpose(1, 2)

    # Sentence modeling module: each sentence yields two groups of block features.
    sent1_block_a, sent1_block_b = self._get_blocks_for_sentence(sent1)
    sent2_block_a, sent2_block_b = self._get_blocks_for_sentence(sent2)

    # Similarity measurement layer: the horizontal comparison uses only the
    # "a" blocks; the vertical comparison uses both groups.
    feat_h = self._algo_1_horiz_comp(sent1_block_a, sent2_block_a)
    feat_v = self._algo_2_vert_comp(
        sent1_block_a, sent2_block_a, sent1_block_b, sent2_block_b
    )

    combined_feats = [feat_h, feat_v]
    if self.ext_feats:
        # Fail early with an actionable message instead of letting torch.cat
        # reject a None element with an opaque error.
        if ext_feats is None:
            raise TypeError(
                "forward() requires ext_feats when self.ext_feats is set"
            )
        combined_feats.append(ext_feats)

    # Concatenate all feature groups along dim 1 and score them.
    feat_all = torch.cat(combined_feats, dim=1)
    return self.final_layers(feat_all)
评论列表
文章目录