rasor_model.py 文件源码

python
阅读 23 收藏 0 点赞 0 评论 0

项目:squad_rasor_nn 作者: hsgodhia 项目源码 文件源码
def _span_sums(self, p_lens, stt, end, max_p_len, batch_size, dim, max_ans_len):
        """Combine start and end projections for every candidate answer span.

        A candidate span is identified by a start position ``i`` in
        ``[0, max_p_len)`` and a length offset ``j`` in ``[0, max_ans_len)``;
        its end position is ``i + j``.  The value produced for a span is
        ``stt[i] + end[i + j]``, where ``end`` is zero-padded so that end
        positions past the last passage row contribute zeros.

        Args:
            p_lens: (batch_size,) passage lengths, used to build the mask.
            stt: (max_p_len, batch_size, dim) start-position projections.
            end: (max_p_len, batch_size, dim) end-position projections.
            max_p_len: number of passage positions (first axis of stt/end).
            batch_size: number of passages in the batch.
            dim: size of the last axis of stt/end.
            max_ans_len: maximum number of positions a span may cover.

        Returns:
            A tuple ``(span_sums, span_masks)``:
            span_sums is (batch_size, max_p_len * max_ans_len, dim), spans
            ordered start-major (all lengths of start 0, then start 1, ...);
            span_masks is a float tensor (batch_size, max_p_len * max_ans_len)
            holding 1.0 where the span's end index is < the passage length
            and 0.0 otherwise.
        """
        # Request int64 explicitly: NumPy's default integer width is platform
        # dependent, and these values are used as tensor indices below.
        starts = torch.from_numpy(np.arange(max_p_len, dtype=np.int64)).unsqueeze(1)  # (max_p_len, 1)
        lengths = torch.from_numpy(np.arange(max_ans_len, dtype=np.int64)).unsqueeze(0)  # (1, max_ans_len)

        # Broadcasting yields end index i + j for every (start, length) pair;
        # flattened it is a plain vector of max_p_len * max_ans_len indices.
        end_idxs_flat = (starts + lengths).view(-1)

        # Padding rows so that end indices up to max_p_len + max_ans_len - 2 are valid.
        zero_t = torch.zeros(max_ans_len - 1, batch_size, dim)

        # Follow the device of the input instead of the global CUDA availability:
        # CPU inputs on a GPU machine (or inputs on a GPU other than 0) must not
        # be mixed with helpers that were unconditionally moved to cuda:0.
        if end.is_cuda:
            device_id = end.get_device()
            zero_t = zero_t.cuda(device_id)
            end_idxs_flat = end_idxs_flat.cuda(device_id)

        end_padded = torch.cat((end, Variable(zero_t)), 0)
        # Gather the end row of every span: (max_p_len * max_ans_len, batch_size, dim)
        end_structed = end_padded[end_idxs_flat].view(max_p_len, max_ans_len, batch_size, dim)

        # FFNN([p_start, p_end]) * W splits into p_start * w_1 + p_end * w_2, so the
        # two projections are computed once by the caller and merely added per span.
        span_sums = stt.unsqueeze(1) + end_structed  # (max_p_len, max_ans_len, batch_size, dim)
        span_sums_reshapped = span_sums.permute(2, 0, 1, 3).contiguous().view(
            batch_size, max_ans_len * max_p_len, dim)

        # A span is valid only if its end index falls inside the passage.
        # NOTE(review): assumes p_lens lives on the same device as `end` - confirm with caller.
        span_masks_reshaped = (Variable(end_idxs_flat.unsqueeze(0)) < p_lens.unsqueeze(1)).float()

        return span_sums_reshapped, span_masks_reshaped

    #q_align_weights = self.softmax(q_align_mask_scores)  # (batch_size, max_p_len, max_q_len)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号