def forward(self, q_input, a_input, drop_rate):
    """
    Run a question/answer pair through the network.

    Pipeline: embedding_layer -> multi_cnn_layer -> interact_layer
    -> batchnorm_layer -> mlp_layer.

    :param q_input: question sentence vec
    :param a_input: answer sentence vec
    :param drop_rate: dropout rate; passed unchanged to ``self.mlp``
    :return: the two values produced by ``self.mlp``, as a 2-tuple
        (named ``prop`` and ``cate`` in the original code -- presumably
        class probabilities and the predicted category; confirm against
        ``self.mlp``)
    """
    # Embed the question first, then the answer, adding a singleton axis at
    # position 1 to each (presumably the channel axis expected by the CNN
    # layers -- TODO confirm).
    question_emb, answer_emb = (
        torch.unsqueeze(self.embedding(tokens), dim=1)
        for tokens in (q_input, a_input)
    )

    question_feat, answer_feat = self.inception_module_layers(
        question_emb, answer_emb
    )

    # Combine the two feature sets, then normalise the combined result.
    normalized = self.bn_layer(self.interact_layer(question_feat, answer_feat))

    # Unpack explicitly so the result is always a 2-tuple, whatever
    # sequence type ``self.mlp`` returns.
    mlp_prop, mlp_cate = self.mlp(normalized, drop_rate)
    return mlp_prop, mlp_cate
评论列表
文章目录