def forward(self, encoded_question, question_length, encoded_support, support_length,
            correct_start, answer2question, is_eval):
    """Score answer-span start and end positions over the support.

    One attention pass over the question yields a single question vector per
    row, which conditions the start scores. The end scores are additionally
    conditioned on the support encoding at the start position.

    Args:
        encoded_question: question encodings; presumably (batch, q_len, dim)
            -- TODO confirm against the encoder.
        question_length: per-question lengths, turned into an additive padding
            mask by ``misc.mask_for_lengths`` before the attention softmax.
        encoded_support: support encodings indexed as (row, position, feature);
            presumably (batch, s_len, dim).
        support_length: per-support lengths, turned into an additive padding
            mask for both start and end scores.
        correct_start: gold start positions, one per answer. Only used when
            ``is_eval`` is false.
        answer2question: for each answer, the batch row (question) it belongs
            to. Used to replicate per-question tensors once per answer in
            training mode.
        is_eval: if true, condition the end scores on the predicted start and
            add a left mask derived from it. Otherwise condition on
            ``correct_start``.

    Returns:
        Tuple ``(start_scores, end_scores, predicted_start_pointer,
        predicted_end_pointer)``. The scores are unnormalized, with the
        additive masks from ``misc.mask_for_lengths`` applied. In training mode
        the start-side outputs are re-indexed per answer via
        ``answer2question``; in eval mode they are per question.
    """
    # Cast indices to a LongTensor on the same backend (CPU/CUDA) as the
    # encodings, as required by index_select/gather below.
    long_tensor = torch.cuda.LongTensor if encoded_question.is_cuda else torch.LongTensor
    answer2question = answer2question.type(long_tensor)

    # Single-pass attention over the question -> one question vector per row.
    attention_scores = self._linear_question_attention(encoded_question)
    q_mask = misc.mask_for_lengths(question_length)
    attention_scores = attention_scores.squeeze(2) + q_mask
    # Explicit dim: implicit-dim softmax is deprecated. For this 2-D score
    # matrix the implicit choice was already dim=1.
    question_attention_weights = F.softmax(attention_scores, dim=1)
    question_state = torch.matmul(question_attention_weights.unsqueeze(1),
                                  encoded_question).squeeze(1)

    # Start prediction, conditioned on the question vector.
    start_input = torch.cat([question_state.unsqueeze(1) * encoded_support, encoded_support], 2)
    q_start_state = self._linear_q_start(start_input) + self._linear_q_start_q(question_state).unsqueeze(1)
    start_scores = self._linear_start_scores(F.relu(q_start_state)).squeeze(2)
    support_mask = misc.mask_for_lengths(support_length)
    start_scores = start_scores + support_mask
    _, predicted_start_pointer = start_scores.max(1)

    def align(t):
        # Re-index rows by answer2question so there is one row per answer.
        return torch.index_select(t, 0, answer2question)

    if is_eval:
        start_pointer = predicted_start_pointer
    else:
        # use correct start during training, because p(end|start) should be optimized
        start_pointer = correct_start.type(long_tensor)
        predicted_start_pointer = align(predicted_start_pointer)
        start_scores = align(start_scores)
        start_input = align(start_input)
        encoded_support = align(encoded_support)
        question_state = align(question_state)
        support_mask = align(support_mask)

    # End prediction, conditioned on the support encoding at the start:
    # u_s[b] = encoded_support[b, start_pointer[b]]. A single gather replaces
    # a per-row Python loop whose `p.data[0]` indexing of 0-dim tensors
    # raises on PyTorch >= 0.4 (and which could not handle zero rows).
    start_index = start_pointer.contiguous().view(-1, 1, 1)
    start_index = start_index.expand(start_index.size(0), 1, encoded_support.size(2))
    u_s = encoded_support.gather(1, start_index).squeeze(1)

    end_input = torch.cat([encoded_support * u_s.unsqueeze(1), start_input], 2)
    q_end_state = self._linear_q_end(end_input) + self._linear_q_end_q(question_state).unsqueeze(1)
    end_scores = self._linear_end_scores(F.relu(q_end_state)).squeeze(2)
    end_scores = end_scores + support_mask

    if is_eval:
        # Add a mask built from the predicted start with mask_right=False
        # (presumably masking positions before the start -- confirm against
        # misc.mask_for_lengths). Its width is taken from end_scores so it
        # always matches the score matrix it is added to.
        end_scores = end_scores + misc.mask_for_lengths(
            start_pointer, end_scores.size(1), mask_right=False)

    _, predicted_end_pointer = end_scores.max(1)
    return start_scores, end_scores, predicted_start_pointer, predicted_end_pointer
# (Scraped web-page residue, not code: "Comment list" / "Table of contents")