def _word_repre_layer(self, input):
    """Build the contextual word-level representation for one sentence side.

    Concatenates the corpus-level embedding of each token with the
    character/word-LSTM encoding produced by ``_word_repre_forward``.

    args:
        - input: (q_sentence, q_words) | (a_sentence, a_words)
            q_sentence - [batch_size, sent_length] token ids
            q_words    - [batch_size, sent_length, words_len] sub-word ids
    return:
        - output: [batch_size, sent_length, context_dim]
          where context_dim = corpus_emb_dim + word_lstm_dim

    NOTE: the parameter name ``input`` shadows the builtin; it is kept
    unchanged so existing keyword callers are not broken.
    """
    sentence, words = input
    # [batch_size, sent_length, corpus_emb_dim]
    s_encode = self.corpus_emb(sentence)
    # [batch_size, sent_length, word_lstm_dim]
    w_encode = self._word_repre_forward(words)
    # BUG FIX: dropout was hard-coded with training=True, so it stayed
    # active during evaluation/inference and randomly zeroed features.
    # Respect the module's train/eval mode instead.
    w_encode = F.dropout(w_encode, p=self.dropout,
                         training=self.training, inplace=False)
    # [batch_size, sent_length, corpus_emb_dim + word_lstm_dim]
    out = torch.cat((s_encode, w_encode), 2)
    return out
# NOTE(review): the following two lines were stray page-navigation text
# left over from a web scrape ("评论列表" = comment list, "文章目录" =
# article table of contents); commented out so the file parses.
# 评论列表
# 文章目录