def _aggre(self, q_aware_reps, a_aware_reps):
    """
    Aggregation layer: collapse both matched sequences into one vector.

    Each input is passed through the shared ``self.aggre_lstm``; the last
    two slices of the returned hidden state are kept for each input, and the
    four slices are concatenated along dim 1 (question slices first, then
    answer slices).

    Args:
        q_aware_reps - [batch_size, question_len, 11*mp_dim+6]
        a_aware_reps - [batch_size, answer_len, 11*mp_dim+6]
    Return:
        size - [batch_size, aggregation_lstm_dim*4]

    NOTE(review): the shapes above are carried over from the original
    docstring. Presumably ``self.aggre_lstm`` is a bidirectional,
    batch-first LSTM, so that ``hidden[-2]`` / ``hidden[-1]`` are the final
    layer's forward / backward states -- confirm against its constructor.
    """
    # Run the question first, then the answer, through the same LSTM and
    # keep only the hidden-state part of each result.
    hidden_states = []
    for aware_reps in (q_aware_reps, a_aware_reps):
        _, (hidden, _) = self.aggre_lstm(aware_reps)
        hidden_states.append(hidden)

    # Resulting order: q[-2], q[-1], a[-2], a[-1].
    pieces = [hidden[idx] for hidden in hidden_states for idx in (-2, -1)]
    return torch.cat(pieces, dim=1)
评论列表
文章目录