static_head.py source code

python

Project: pytorch-dnc  Author: jingweiz
def _content_focus(self, memory_vb):
        """
        variables needed:
            key_vb:    [batch_size x num_heads x mem_wid]
                    -> similarity key vector, to compare to each row in memory
                    -> by cosine similarity
            beta_vb:   [batch_size x num_heads x 1]
                    -> NOTE: refer here: https://github.com/deepmind/dnc/issues/9
                    -> in (1, +inf) after oneplus(); similarity key strength
                    -> amplify or attenuate the precision of the focus
            memory_vb: [batch_size x mem_hei   x mem_wid]
        returns:
            wc_vb:     [batch_size x num_heads x mem_hei]
                    -> the attention weight by content focus
        """
        K_vb = batch_cosine_sim(self.key_vb, memory_vb)  # [batch_size x num_heads x mem_hei]
        self.wc_vb = K_vb * self.beta_vb.expand_as(K_vb) # [batch_size x num_heads x mem_hei]
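        # softmax over the mem_hei axis; an explicit dim replaces the deprecated implicit-dim transpose trick
        self.wc_vb = F.softmax(self.wc_vb, dim=2)        # [batch_size x num_heads x mem_hei]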
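
To make the computation above easier to follow, here is a minimal pure-Python sketch of content focus for a single batch element and a single head. It is not the project's implementation: the helper names `cosine_sim`, `softmax`, and `content_focus`, the `eps` guard, and the toy memory are all assumptions made for illustration. The real method works on batched tensors through `batch_cosine_sim` and `F.softmax`.

```python
import math


def cosine_sim(u, v, eps=1e-6):
    """Cosine similarity of two equal-length vectors.

    eps is an assumed guard against division by zero for all-zero rows.
    """
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v + eps)


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def content_focus(key, beta, memory):
    """Attention weights over memory rows for one head.

    key:    list of mem_wid floats
    beta:   scalar key strength (greater than 1 after oneplus in the original)
    memory: list of mem_hei rows, each with mem_wid floats
    """
    # compare the key to every memory row, then sharpen with beta
    scores = [beta * cosine_sim(key, row) for row in memory]
    # normalize across memory rows so the weights sum to 1
    return softmax(scores)


# toy memory with mem_hei = 3 rows and mem_wid = 2 columns (hypothetical values)
memory = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
weights = content_focus(key=[1.0, 0.0], beta=5.0, memory=memory)
sharper = content_focus(key=[1.0, 0.0], beta=50.0, memory=memory)
```

In this toy setup the first row matches the key exactly, so it receives the largest weight, and raising `beta` concentrates even more of the weight on that row. This is the "amplify or attenuate the precision of the focus" behavior described in the docstring.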