def forward(self, h, u, h_mask=None, u_mask=None):
    config = self.config
    if config.q2c_att or config.c2q_att:
        # u_a: [N, M, JX, d] -- query-aware context vectors (context-to-query attention)
        # h_a: [N, M, d]     -- attended summary of the context (query-to-context attention)
        u_a, h_a = self.bi_attention(h, u, h_mask=h_mask, u_mask=u_mask)
    else:
        raise NotImplementedError(
            "AttentionLayer: q2c_att and c2q_att both False is not supported")
    if config.q2c_att:
        # Tile h_a over the JX axis ([N, M, d] -> [N, M, JX, d]) so the
        # elementwise products below are defined.
        h_a = h_a.unsqueeze(2).expand_as(h)
        p0 = torch.cat([h, u_a, torch.mul(h, u_a), torch.mul(h, h_a)], dim=3)
    else:
        raise NotImplementedError("AttentionLayer: q2c_att False is not supported")
    return p0
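
The concatenation builds the query-aware representation G = [h; u_a; h ∘ u_a; h ∘ h_a] from the BiDAF paper, so the output's last dimension is four times that of h. Below is a minimal, self-contained sketch of this shape behavior; `AttentionLayerSketch`, the random `bi_attention` stub, and the `SimpleNamespace` config are stand-ins for illustration, not the real implementations.

import torch
import torch.nn as nn
from types import SimpleNamespace


class AttentionLayerSketch(nn.Module):
    """Stand-in that reproduces the forward() above with a random bi_attention."""

    def __init__(self, config):
        super().__init__()
        self.config = config

    def bi_attention(self, h, u, h_mask=None, u_mask=None):
        # Stub returning tensors with the documented shapes:
        # u_a: [N, M, JX, d], h_a: [N, M, d]. A real implementation derives
        # both from a similarity matrix between h and u.
        N, M, JX, d = h.shape
        return torch.randn(N, M, JX, d), torch.randn(N, M, d)

    def forward(self, h, u, h_mask=None, u_mask=None):
        u_a, h_a = self.bi_attention(h, u, h_mask=h_mask, u_mask=u_mask)
        h_a = h_a.unsqueeze(2).expand_as(h)  # tile over JX
        return torch.cat([h, u_a, h * u_a, h * h_a], dim=3)


config = SimpleNamespace(q2c_att=True, c2q_att=True)
layer = AttentionLayerSketch(config)
h = torch.randn(2, 1, 30, 100)   # [N, M, JX, d] context encoding
u = torch.randn(2, 10, 100)      # [N, JQ, d] query encoding (consumed by bi_attention)
print(layer(h, u).shape)         # torch.Size([2, 1, 30, 400]) -- last dim is 4*d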