layers.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:BiDAF-PyTorch 作者: kelayamatoz 项目源码 文件源码
def forward(self, h, u, h_mask=None, u_mask=None):
        config = self.config
        if config.q2c_att or config.c2q_att:
            u_a, h_a = self.bi_attention(h, u, h_mask=h_mask, u_mask=u_mask)
            '''
            u_a: [N, M, JX, d]
            h_a: [N, M, d]
            '''
        else:
            print("AttentionLayer: q2c_att or c2q_att False not supported")

        if config.q2c_att:
            p0 = torch.cat([h, u_a, torch.mul(h, u_a), torch.mul(h, h_a)], 3)
        else:
            print("AttentionLayer: q2c_att False not supported")

        return p0
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号