qa_cnn.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:DBQA 作者: nanfeng1101 项目源码 文件源码
def forward(self, q_input, a_input, drop_rate):
        """
        input -> embedding_layer -> multi_cnn_layer -> interact_layer -> batchnorm_layer -> mlp_layer
        :param q_input: question sentence vec
        :param a_input: answer sentence vec
        :param: drop_rate: dropout rate
        :return:
        """
        q_input_emb = torch.unsqueeze(self.embedding(q_input), dim=1)
        a_input_emb = torch.unsqueeze(self.embedding(a_input), dim=1)
        q_vec, a_vec = self.inception_module_layers(q_input_emb, a_input_emb)
        qa_vec = self.interact_layer(q_vec, a_vec)
        bn_vec = self.bn_layer(qa_vec)
        prop, cate = self.mlp(bn_vec, drop_rate)
        return prop, cate
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号