def forward(self, q_corpora, q_words, a_corpora, a_words):
    """
    Run the full matching pipeline and return softmax-normalised predictions.

    Pipeline: masks -> word representations -> pairwise cosine similarity
    -> bilateral matching -> aggregation -> two-layer prediction head.

    Args:
        q_corpora: Tensor for the "q" side; compared element-wise against
            ``PF_POS`` to build ``self.q_mask`` and fed to the word
            representation layer.
        q_words: Companion input for the "q" side, passed to the word
            representation layer together with ``q_corpora``.
        a_corpora: Same as ``q_corpora`` but for the "a" side.
        a_words: Same as ``q_words`` but for the "a" side.

    Returns:
        The output of ``self.l2`` normalised with softmax over its last
        (feature) dimension.

    Side effects:
        Stores ``q_mask``, ``a_mask``, ``q_repres`` and ``a_repres`` on
        ``self``. NOTE(review): ``_cosine_similarity_mask`` and
        ``_bilateral_match`` only receive the similarity tensor, so they
        presumably read these attributes - keep the assignments.
    """
    # Step 1 - Get masks from q_corpora and a_corpora
    # (True where the value is >= PF_POS; presumably the non-padding
    # positions - confirm against the definition of PF_POS).
    self.q_mask = q_corpora.ge(PF_POS)
    self.a_mask = a_corpora.ge(PF_POS)

    # Step 2 - Word Representation Layer
    self.q_repres = self._word_repre_layer((q_corpora, q_words))
    self.a_repres = self._word_repre_layer((a_corpora, a_words))

    # Step 3 - Cosine similarity and mask
    # Insert singleton axes so the two sides broadcast against each other.
    iqr_temp = self.q_repres.unsqueeze(1)  # [bsz, 1, q_len, context_dim]
    ipr_temp = self.a_repres.unsqueeze(2)  # [bsz, a_len, 1, context_dim]
    # Reduce over the feature axis -> [bsz, a_len, q_len]
    simi = F.cosine_similarity(iqr_temp, ipr_temp, dim=3)
    simi_mask = self._cosine_similarity_mask(simi)

    # Step 4 - Matching Layer
    q_aware_reps, a_aware_reps = self._bilateral_match(simi_mask)
    # F.dropout defaults to training=True; tie it to the module mode so
    # dropout is disabled under .eval().
    q_aware_reps = F.dropout(
        q_aware_reps, p=self.dropout, training=self.training
    )
    a_aware_reps = F.dropout(
        a_aware_reps, p=self.dropout, training=self.training
    )

    # Step 5 - Aggregation Layer
    aggre = self._aggre(q_aware_reps, a_aware_reps)

    # Step 6 - Prediction Layer
    # Tensor.tanh() replaces the deprecated F.tanh.
    predict = self.l1(aggre).tanh()
    predict = F.dropout(predict, p=self.dropout, training=self.training)
    # Explicit dim: normalise over the output features of self.l2 instead
    # of relying on the deprecated implicit-dim choice (same result for
    # 2-D input).
    return F.softmax(self.l2(predict), dim=-1)
# Comment list (web page navigation residue)
# Article table of contents (web page navigation residue)