import torch
import torch.nn.functional as F


def _shift(self, wg_vb, shift_vb):
"""
variables needed:
wg_vb: [batch_size x num_heads x mem_hei]
shift_vb: [batch_size x num_heads x num_allowed_shifts]
-> sum=1; the shift weight vector
returns:
ws_vb: [batch_size x num_heads x mem_hei]
-> the attention weight by location focus
"""
batch_size = wg_vb.size(0)
input_dim = wg_vb.size(2); assert input_dim == self.mem_hei
filter_dim = shift_vb.size(2); assert filter_dim == self.num_allowed_shifts
    ws_list = []
    # the shift kernel differs for each head in each batch, so there seems to
    # be no way to avoid looping over both dimensions here
    for i in range(batch_size):
        for j in range(self.num_heads):
            # tile the weight vector three times so that the convolution wraps
            # around the memory (a circular shift); the middle input_dim entries
            # are then the shifted weighting (assumes an odd number of shifts, e.g. 3)
            ws_tmp_vb = F.conv1d(wg_vb[i][j].unsqueeze(0).unsqueeze(0).repeat(1, 1, 3),
                                 shift_vb[i][j].unsqueeze(0).unsqueeze(0).contiguous(),
                                 padding=filter_dim // 2)[:, :, input_dim:(2 * input_dim)]
            ws_list.append(ws_tmp_vb)
    ws_vb = torch.cat(ws_list, 0).view(-1, self.num_heads, self.mem_hei)
return ws_vb
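

# --- usage sketch (not part of the original module) --------------------------
# A minimal shape check for `_shift`, assuming the function above is bound to an
# object that defines `mem_hei`, `num_heads` and `num_allowed_shifts`. The
# stand-in class `_Head` and the sizes below are hypothetical, chosen only for
# illustration.
class _Head(object):
    mem_hei = 16             # number of memory rows
    num_heads = 2            # number of read/write heads
    num_allowed_shifts = 3   # e.g. shifts of -1, 0, +1
    shift = _shift           # bind the function defined above as a method


if __name__ == "__main__":
    head = _Head()
    batch_size = 4
    # both weightings are distributions over their last dimension (rows sum to 1)
    wg_vb = torch.softmax(torch.randn(batch_size, head.num_heads, head.mem_hei), dim=-1)
    shift_vb = torch.softmax(torch.randn(batch_size, head.num_heads, head.num_allowed_shifts), dim=-1)
    ws_vb = head.shift(wg_vb, shift_vb)
    print(ws_vb.shape)    # expected: torch.Size([4, 2, 16])
    print(ws_vb.sum(-1))  # each row still sums to ~1, since the shift kernel sums to 1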