static_write_head.py source code


Project: pytorch-dnc  Author: jingweiz
```python
import torch


def _access(self, memory_vb):  # write
    """
    Inputs (attributes on self, except memory_vb):
        wl_curr_vb: [batch_size x num_heads x mem_hei]  write weighting per head
        erase_vb:   [batch_size x num_heads x mem_wid]  values in (0, 1)
        add_vb:     [batch_size x num_heads x mem_wid]  unconstrained range
        memory_vb:  [batch_size x mem_hei x mem_wid]
    Returns:
        memory_vb:  [batch_size x mem_hei x mem_wid]
    NOTE: IMPORTANT: https://github.com/deepmind/dnc/issues/10
    """
    # Erase: outer product of each head's weighting and erase vector,
    # reshaped to [batch_size x num_heads x mem_hei x mem_wid]
    weighted_erase_vb = torch.bmm(self.wl_curr_vb.contiguous().view(-1, self.mem_hei, 1),
                                  self.erase_vb.contiguous().view(-1, 1, self.mem_wid)).view(-1, self.num_heads, self.mem_hei, self.mem_wid)
    # Combine the heads multiplicatively: keep = prod_h (1 - w_h e_h^T)
    keep_vb = torch.prod(1. - weighted_erase_vb, dim=1)
    memory_vb = memory_vb * keep_vb
    # Add: [batch x mem_hei x num_heads] @ [batch x num_heads x mem_wid]
    # sums every head's w_h a_h^T contribution
    return memory_vb + torch.bmm(self.wl_curr_vb.transpose(1, 2), self.add_vb)
```
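To make the arithmetic easier to follow without PyTorch, here is a minimal pure-Python sketch of the same erase-then-add rule for a single batch element. The function name `multi_head_write` and the list-of-lists layout are illustrative assumptions, not part of pytorch-dnc. For each memory cell (n, w) it computes M[n][w] ← M[n][w] · ∏_h (1 − w_h[n]·e_h[w]) + Σ_h w_h[n]·a_h[w]. This is what the `torch.prod(..., dim=1)` over the per-head outer products and the final `bmm` with the transposed weighting compute in the method above.

```python
from typing import List

Matrix = List[List[float]]


def multi_head_write(memory: Matrix, weights: Matrix, erase: Matrix, add: Matrix) -> Matrix:
    """Hypothetical helper: one multi-head erase-then-add write, single batch element.

    memory:  [mem_hei][mem_wid]
    weights: [num_heads][mem_hei]   write weighting per head
    erase:   [num_heads][mem_wid]   erase vector per head, values in (0, 1)
    add:     [num_heads][mem_wid]   add vector per head, unconstrained
    """
    num_heads = len(weights)
    mem_hei = len(memory)
    mem_wid = len(memory[0])
    new_memory = []
    for n in range(mem_hei):
        row = []
        for w in range(mem_wid):
            # Erase: multiply the per-head keep factors (1 - w_h[n] * e_h[w]).
            keep = 1.0
            for h in range(num_heads):
                keep *= 1.0 - weights[h][n] * erase[h][w]
            # Add: sum the per-head contributions w_h[n] * a_h[w].
            added = sum(weights[h][n] * add[h][w] for h in range(num_heads))
            row.append(memory[n][w] * keep + added)
        new_memory.append(row)
    return new_memory


memory = [[1.0, 2.0], [3.0, 4.0]]
weights = [[1.0, 0.0], [0.5, 0.5]]  # head 0 writes only row 0; head 1 splits evenly
erase = [[0.5, 0.25], [0.5, 0.5]]
add = [[1.0, -1.0], [2.0, 0.0]]
print(multi_head_write(memory, weights, erase, add))
# -> [[2.375, 0.125], [3.25, 3.0]]
```

Multiplying the per-head keep factors is equivalent to applying each head's erase one after another. When the weightings and erase values lie in [0, 1], every factor and therefore their product also stays in [0, 1]. If the per-head erase terms were summed instead (1 − Σ_h w_h e_hᵀ), the keep factor could go negative when two heads strongly target the same location.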