static_head.py source code

python

Project: pytorch-dnc  Author: jingweiz
import torch
import torch.nn.functional as F

def forward(self, hidden_vb, memory_vb):
        # project the controller's hidden state into the addressing parameters for the heads
        # NOTE: to be consistent w/ the dnc paper, we use
        # NOTE: sigmoid to constrain to (0, 1)
        # NOTE: oneplus (1 + softplus) to constrain to [1, +inf)
        self.key_vb   = torch.tanh(self.hid_2_key(hidden_vb)).view(-1, self.num_heads, self.mem_wid)    # key in (-1, 1); TODO: check whether relu would be better, to bias the memory to store positive values
        self.beta_vb  = F.softplus(self.hid_2_beta(hidden_vb)).view(-1, self.num_heads, 1)          # softplus alone gives beta > 0; the dnc paper uses beta >= 1, see https://github.com/deepmind/dnc/issues/9
        self.gate_vb  = torch.sigmoid(self.hid_2_gate(hidden_vb)).view(-1, self.num_heads, 1)       # gate in (0, 1): interpolation gate, blends wl_{t-1} & wc
        self.shift_vb = F.softmax(self.hid_2_shift(hidden_vb).view(-1, self.num_heads, self.num_allowed_shifts), dim=-1)    # shift: weights sum to 1 over the allowed shifts
        self.gamma_vb = (1. + F.softplus(self.hid_2_gamma(hidden_vb))).view(-1, self.num_heads, 1)  # gamma >= 1: sharpens the final weights

        # now we compute the addressing mechanism
        self._content_focus(memory_vb)
        self._location_focus()
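
The range constraints named in the comments above can be checked with a small standalone sketch. It uses only the standard library and scalar or list inputs rather than tensors. The function names (`softplus`, `oneplus`, `sigmoid`, `softmax`, `head_parameters`) and the sample inputs are hypothetical and chosen for illustration; this is not the repository's code.

```python
import math


def softplus(x):
    # Numerically stable log(1 + exp(x)); the result is always > 0.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def oneplus(x):
    # 1 + softplus(x); the result is always >= 1.
    return 1.0 + softplus(x)


def sigmoid(x):
    # Logistic function, written to avoid overflow for large negative x.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def softmax(xs):
    # Subtract the max before exponentiating for stability; outputs sum to 1.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def head_parameters(raw_key, raw_beta, raw_gate, raw_shift, raw_gamma):
    # Hypothetical helper: applies the same activations as forward() above
    # to the raw (unconstrained) outputs for a single head.
    return {
        "key": [math.tanh(k) for k in raw_key],  # each entry in (-1, 1)
        "beta": softplus(raw_beta),              # > 0 (not >= 1)
        "gate": sigmoid(raw_gate),               # in (0, 1)
        "shift": softmax(raw_shift),             # sums to 1
        "gamma": oneplus(raw_gamma),             # >= 1
    }


params = head_parameters(
    raw_key=[-2.0, 0.5, 3.0],
    raw_beta=-4.0,
    raw_gate=0.0,
    raw_shift=[0.1, 2.0, -1.0],
    raw_gamma=-4.0,
)
```

With the same raw value of -4.0, `beta` falls below 1 while `gamma` stays above it, which is the difference between plain softplus and oneplus that the linked issue discusses.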