def attention_mul(rnn_outputs, att_weights):
    """Return the attention-weighted sum of RNN outputs over the time axis.

    Args:
        rnn_outputs: tensor of shape ``(seq_len, batch, hidden)`` — the
            per-timestep RNN hidden states. (3-D is required: the original
            ``unsqueeze(1).expand_as`` pattern only works when each
            ``att_weights[i]`` is 1-D.)
        att_weights: tensor of shape ``(seq_len, batch)`` — one scalar
            attention weight per timestep per batch element.

    Returns:
        Tensor of shape ``(batch, hidden)``: ``sum_i att_weights[i] *
        rnn_outputs[i]`` — the context vector for each batch element.
    """
    # Broadcast each weight across the hidden dimension, then reduce over
    # time.  Equivalent to the per-timestep loop with repeated torch.cat,
    # but vectorized: one multiply + one sum instead of T concatenations.
    weighted = rnn_outputs * att_weights.unsqueeze(2)
    return torch.sum(weighted, 0)
# (webpage-extraction residue, not code: "评论列表" = comment list,
#  "文章目录" = table of contents — left here, commented out, so the file parses)