dynamic_read_head.py source code

python

Project: pytorch-dnc    Author: jingweiz
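```python
import torch

# Method excerpted from the read-head class in dynamic_read_head.py;
# self.hid_2_free_gate, self.num_heads and self.wl_prev_vb are defined elsewhere in that class.
def _update_usage(self, hidden_vb, prev_usage_vb):
    """
    Calculates the new usage after reading from and freeing memory.

    Inputs:
        hidden_vb:     [batch_size x hidden_dim]
        prev_usage_vb: [batch_size x mem_hei]
    Attributes read:
        wl_prev_vb:    [batch_size x num_heads x mem_hei]  (previous read weights)
    Attributes set:
        free_gate_vb:  [batch_size x num_heads x 1]
    Returns:
        usage_vb:      [batch_size x mem_hei]
    """
    # One free gate per read head, squashed to (0, 1).
    # torch.sigmoid replaces the deprecated F.sigmoid.
    self.free_gate_vb = torch.sigmoid(self.hid_2_free_gate(hidden_vb)).view(-1, self.num_heads, 1)
    # Scale each head's previous read weights by that head's free gate.
    free_read_weights_vb = self.free_gate_vb.expand_as(self.wl_prev_vb) * self.wl_prev_vb
    # Retention vector psi: product over the head dimension (dim 1).
    psi_vb = torch.prod(1. - free_read_weights_vb, 1)
    return prev_usage_vb * psi_vb
```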
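The method above applies the free-gate (retention) part of the Differentiable Neural Computer usage update. In the DNC formulation, with free gates $f_t^i$ and previous read weights $w_{t-1}^{r,i}$ for read head $i$, the retention vector and full usage update are

$$\psi_t = \prod_{i=1}^{R} \left(1 - f_t^i \, w_{t-1}^{r,i}\right), \qquad u_t = \left(u_{t-1} + w_{t-1}^w - u_{t-1} \circ w_{t-1}^w\right) \circ \psi_t$$

and this snippet multiplies the usage it is given by $\psi_t$ only; the write-weight term $w_{t-1}^w$ does not appear here. The sketch below reproduces the same computation with NumPy instead of PyTorch so it can run without a model. The function name `update_usage` and its argument names are hypothetical, and the free-gate logits are passed in directly in place of the `hid_2_free_gate` layer.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    """Elementwise logistic function."""
    return 1.0 / (1.0 + np.exp(-x))


def update_usage(prev_usage: np.ndarray,
                 free_gate_logits: np.ndarray,
                 prev_read_weights: np.ndarray) -> np.ndarray:
    """Hypothetical NumPy version of the free-gate usage update.

    prev_usage:        [batch_size, mem_hei]
    free_gate_logits:  [batch_size, num_heads]  (pre-sigmoid; stands in for hid_2_free_gate output)
    prev_read_weights: [batch_size, num_heads, mem_hei]
    returns:           [batch_size, mem_hei]
    """
    # One gate per head in (0, 1), with a trailing axis so it broadcasts
    # over memory locations (the role of .view(..., 1).expand_as in the original).
    free_gate = sigmoid(free_gate_logits)[:, :, None]
    # How strongly each head asks to free each location it read last step.
    free_read_weights = free_gate * prev_read_weights
    # Retention vector psi: product over the head axis (axis 1).
    psi = np.prod(1.0 - free_read_weights, axis=1)
    return prev_usage * psi


# Usage: two read heads, three memory locations, both free gates at sigmoid(0) = 0.5.
prev_usage = np.array([[1.0, 0.5, 0.0]])
logits = np.zeros((1, 2))
reads = np.array([[[1.0, 0.0, 0.0],
                   [0.0, 0.5, 0.0]]])
new_usage = update_usage(prev_usage, logits, reads)  # [[0.5, 0.375, 0.0]]
```

Here psi works out to [0.5, 0.75, 1.0]: a location is only partially freed because each gate is 0.5, and a location no head read (the last one) keeps its usage. Taking the product over heads means any single head can free a location on its own.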