static_head.py source code

python

Project: pytorch-dnc · Author: jingweiz
import torch
import torch.nn.functional as F

def _shift(self, wg_vb, shift_vb):
        """
        variables needed:
            wg_vb:    [batch_size x num_heads x mem_hei]
            shift_vb: [batch_size x num_heads x num_allowed_shifts]
                   -> sum=1; the shift weight vector
        returns:
            ws_vb:    [batch_size x num_heads x mem_hei]
                   -> the attention weight by location focus
        """
        batch_size = wg_vb.size(0)
        input_dim = wg_vb.size(2)
        assert input_dim == self.mem_hei
        filter_dim = shift_vb.size(2)
        assert filter_dim == self.num_allowed_shifts

        ws_vb = None
        # each (batch, head) pair has its own shift kernel, so we loop over
        # batches and heads here (see the grouped-conv alternative sketched below)
        for i in range(batch_size):
            for j in range(self.num_heads):
                # tile the weighting 3x to emulate circular padding, convolve
                # with the shift kernel, then keep the middle segment: this is
                # a circular convolution over memory locations
                ws_tmp_vb = F.conv1d(wg_vb[i][j].unsqueeze(0).unsqueeze(0).repeat(1, 1, 3),
                                     shift_vb[i][j].unsqueeze(0).unsqueeze(0).contiguous(),
                                     padding=filter_dim//2)[:, :, input_dim:(2*input_dim)]
                if ws_vb is None:
                    ws_vb = ws_tmp_vb
                else:
                    ws_vb = torch.cat((ws_vb, ws_tmp_vb), 0)
        ws_vb = ws_vb.view(-1, self.num_heads, self.mem_hei)
        return ws_vb
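
The tile-convolve-crop pattern above is the circular convolution used for NTM/DNC-style location-based addressing: repeating the weighting three times emulates wrap-around padding, and cropping the middle input_dim positions yields the rotated weighting. As a sanity check, here is a minimal, self-contained sketch (the names circular_shift / circular_shift_grouped and the toy shapes are mine, not from the repo) showing a one-hot kernel rotating the focus by one slot, plus a grouped-conv variant that produces the same result without the Python loop:

import torch
import torch.nn.functional as F

def circular_shift(wg_vb, shift_vb):
    """Loop version, mirroring _shift: one conv1d per (batch, head) pair."""
    batch_size, num_heads, mem_hei = wg_vb.size()
    filter_dim = shift_vb.size(2)
    rows = []
    for i in range(batch_size):
        for j in range(num_heads):
            # tile 3x -> convolve -> keep the middle segment = circular conv
            rows.append(F.conv1d(wg_vb[i][j].repeat(3).view(1, 1, -1),
                                 shift_vb[i][j].view(1, 1, -1),
                                 padding=filter_dim // 2)[:, :, mem_hei:2 * mem_hei])
    return torch.cat(rows, 0).view(batch_size, num_heads, mem_hei)

def circular_shift_grouped(wg_vb, shift_vb):
    """Same result in a single grouped conv1d: fold (batch, head) into the
    channel dimension so each channel is convolved with its own kernel."""
    batch_size, num_heads, mem_hei = wg_vb.size()
    filter_dim = shift_vb.size(2)
    inp = wg_vb.reshape(1, batch_size * num_heads, mem_hei).repeat(1, 1, 3)
    ker = shift_vb.reshape(batch_size * num_heads, 1, filter_dim)
    out = F.conv1d(inp, ker, padding=filter_dim // 2,
                   groups=batch_size * num_heads)[:, :, mem_hei:2 * mem_hei]
    return out.reshape(batch_size, num_heads, mem_hei)

wg = torch.zeros(1, 1, 5)
wg[0, 0, 2] = 1.0                          # all focus on memory slot 2
shift = torch.tensor([[[1.0, 0.0, 0.0]]])  # one-hot shift kernel
print(circular_shift(wg, shift))           # focus rotates to slot 3
print(circular_shift_grouped(wg, shift))   # identical result, no Python loop

Two caveats: F.conv1d computes cross-correlation rather than true convolution, so the first kernel entry moves the focus toward higher memory indices; and the middle-segment crop assumes an odd num_allowed_shifts (e.g. 3), so that the padded output keeps length 3 * mem_hei.
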