fastqa.py 文件源码

python
阅读 49 收藏 0 点赞 0 评论 0

项目:jack 作者: uclmr 项目源码 文件源码
def forward(self, encoded_question, question_length, encoded_support, support_length,
                correct_start, answer2question, is_eval):
        """Score answer-span start and end positions over the support.

        The question is pooled into a single state with one attention pass,
        start scores are computed for every support position, and end scores
        are then computed conditioned on a start position (the gold start
        during training, the predicted start during evaluation).

        Args:
            encoded_question: encoded question tokens; indexed here as
                (batch, question_len, dim) -- TODO confirm with caller.
            question_length: per-question lengths passed to
                ``misc.mask_for_lengths``.
            encoded_support: encoded support tokens; indexed here as
                (batch, support_len, dim) -- TODO confirm with caller.
            support_length: per-support lengths passed to
                ``misc.mask_for_lengths``.
            correct_start: gold start positions; only read when ``is_eval``
                is false.
            answer2question: index tensor used with ``torch.index_select``
                along dim 0 to align per-question tensors with answers; only
                used when ``is_eval`` is false.
            is_eval: selects evaluation behaviour (predicted start pointer,
                extra mask on the end scores) over training behaviour.

        Returns:
            Tuple ``(start_scores, end_scores, predicted_start_pointer,
            predicted_end_pointer)``. During training ``start_scores`` and
            ``predicted_start_pointer`` are returned after re-indexing with
            ``answer2question``.
        """
        # casting
        long_tensor = torch.cuda.LongTensor if encoded_question.is_cuda else torch.LongTensor
        answer2question = answer2question.type(long_tensor)

        # computing single time attention over question
        attention_scores = self._linear_question_attention(encoded_question)
        # additive mask; presumably strongly negative at padded positions -- verify in misc
        q_mask = misc.mask_for_lengths(question_length)
        attention_scores = attention_scores.squeeze(2) + q_mask
        # Explicit dim: the implicit choice is deprecated. For the 2-D scores
        # produced above it resolved to dim 1, so the result is unchanged.
        question_attention_weights = F.softmax(attention_scores, dim=1)
        question_state = torch.matmul(question_attention_weights.unsqueeze(1),
                                      encoded_question).squeeze(1)

        # Prediction
        # start
        start_input = torch.cat([question_state.unsqueeze(1) * encoded_support, encoded_support], 2)

        q_start_state = self._linear_q_start(start_input) + self._linear_q_start_q(question_state).unsqueeze(1)
        start_scores = self._linear_start_scores(F.relu(q_start_state)).squeeze(2)

        support_mask = misc.mask_for_lengths(support_length)
        start_scores = start_scores + support_mask
        _, predicted_start_pointer = start_scores.max(1)

        def align(t):
            # re-index dim 0 of a per-question tensor so it lines up with the answers
            return torch.index_select(t, 0, answer2question)

        if is_eval:
            start_pointer = predicted_start_pointer
        else:
            # use correct start during training, because p(end|start) should be optimized
            start_pointer = correct_start.type(long_tensor)
            predicted_start_pointer = align(predicted_start_pointer)
            start_scores = align(start_scores)
            start_input = align(start_input)
            encoded_support = align(encoded_support)
            question_state = align(question_state)
            support_mask = align(support_mask)

        # end
        # Pick the support state at each start position with a single gather
        # instead of a Python loop over the batch. The loop indexed every
        # pointer with `.data[0]`, which fails on 0-dim tensors in current
        # PyTorch. Start positions are assumed to be non-negative indices.
        gather_index = start_pointer.view(-1, 1, 1).expand(-1, 1, encoded_support.size(2))
        u_s = encoded_support.gather(1, gather_index).squeeze(1)

        end_input = torch.cat([encoded_support * u_s.unsqueeze(1), start_input], 2)

        q_end_state = self._linear_q_end(end_input) + self._linear_q_end_q(question_state).unsqueeze(1)
        end_scores = self._linear_end_scores(F.relu(q_end_state)).squeeze(2)

        end_scores = end_scores + support_mask

        if is_eval:
            # Only evaluation needs the maximum support length, so the
            # tensor-to-Python conversion is skipped during training.
            max_support = support_length.max().item()
            # presumably masks end positions left of the predicted start -- verify in misc
            end_scores += misc.mask_for_lengths(start_pointer, max_support, mask_right=False)

        _, predicted_end_pointer = end_scores.max(1)

        return start_scores, end_scores, predicted_start_pointer, predicted_end_pointer
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号