def _span_sums(self, p_lens, stt, end, max_p_len, batch_size, dim, max_ans_len):
# stt (max_p_len, batch_size, dim)
# end (max_p_len, batch_size, dim)
# p_lens (batch_size,)
max_ans_len_range = torch.from_numpy(np.arange(max_ans_len))
max_ans_len_range = max_ans_len_range.unsqueeze(0) # (1, max_ans_len) is a vector like [0,1,2,3,4....,max_ans_len-1]
offsets = torch.from_numpy(np.arange(max_p_len))
offsets = offsets.unsqueeze(0) # (1, max_p_len) is a vector like (0,1,2,3,4....max_p_len-1)
offsets = offsets.transpose(0, 1) # (max_p_len, 1) is row vector now like [0/1/2/3...max_p_len-1]
end_idxs = max_ans_len_range.expand(offsets.size(0), max_ans_len_range.size(1)) + offsets.expand(offsets.size(0), max_ans_len_range.size(1))
#pdb.set_trace()
end_idxs_flat = end_idxs.view(-1, 1).squeeze(1) # (max_p_len*max_ans_len, )
# note: this is not modeled as tensor of size (SZ, 1) but vector of SZ size
zero_t = torch.zeros(max_ans_len - 1, batch_size, dim)
if torch.cuda.is_available():
zero_t = zero_t.cuda(0)
end_idxs_flat = end_idxs_flat.cuda(0)
end_padded = torch.cat((end, Variable(zero_t)), 0)
end_structed = end_padded[end_idxs_flat] # (max_p_len*max_ans_len, batch_size, dim)
end_structed = end_structed.view(max_p_len, max_ans_len, batch_size, dim)
stt_shuffled = stt.unsqueeze(1) # stt (max_p_len, 1, batch_size, dim)
# since the FFNN(h_a) * W we expand h_a as [p_start, p_end]*[w_1 w_2] so this reduces to p_start*w_1 + p_end*w_2
# now we can reuse the operations, we compute only once
span_sums = stt_shuffled.expand(max_p_len, max_ans_len, batch_size, dim) + end_structed # (max_p_len, max_ans_len, batch_size, dim)
span_sums_reshapped = span_sums.permute(2, 0, 1, 3).contiguous().view(batch_size, max_ans_len * max_p_len, dim)
p_lens_shuffled = p_lens.unsqueeze(1)
end_idxs_flat_shuffled = end_idxs_flat.unsqueeze(0)
span_masks_reshaped = Variable(end_idxs_flat_shuffled.expand(p_lens_shuffled.size(0), end_idxs_flat_shuffled.size(1))) < p_lens_shuffled.expand(p_lens_shuffled.size(0), end_idxs_flat_shuffled.size(1))
span_masks_reshaped = span_masks_reshaped.float()
return span_sums_reshapped, span_masks_reshaped