hm.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:holographic_memory 作者: jramapuram 项目源码 文件源码
def decode(self, memories, keys, num_keys=None):
    """Retrieve values from ``memories`` using (permuted, conjugated) ``keys``.

    The keys are normalized with ``self._normalize`` and permuted with
    ``self.perm_keys(keys, self.perms, ...)``. The resulting rows are then
    re-gathered so rows belonging to the same key are grouped together. The
    re-ordered keys are passed to ``self.conv_func`` with ``conj=True``.

    Args:
        memories: memory tensor. If its static leading dimension is unknown,
            ``self.num_models`` is used in its place. Its second static
            dimension fills in an unknown second dimension of the permuted
            keys.
        keys: key tensor, normalized before use.
        num_keys: number of keys. Defaults to the static leading dimension of
            the normalized ``keys``.

    Returns:
        The result of ``self.conv_func(memories, permed_keys, ...)`` called
        with ``num_keys=num_keys * self.num_models`` and ``conj=True``.
    """
    # Local import so this block does not depend on file-level imports; the
    # diagnostics below used to be unconditional ``print`` statements
    # (stdout noise on every call, and Python-2-only syntax).
    import logging
    logger = logging.getLogger(__name__)

    keys = self._normalize(keys)

    # Static shape of the memories (a list, not a count); an unknown leading
    # dimension falls back to the number of models.
    mem_shape = memories.get_shape().as_list()
    if mem_shape[0] is None:
        mem_shape[0] = self.num_models
    if num_keys is None:
        num_keys = keys.get_shape().as_list()[0]
    logger.debug('decode: numkeys = %s | num_memories = %s',
                 num_keys, mem_shape)

    # Permute the keys. The expected row count (num_keys * num_models) is only
    # used as a fallback when the static shape is unknown.
    perms = self.perm_keys(keys, self.perms, num_keys=num_keys)
    pshp = perms.get_shape().as_list()
    if pshp[0] is None:
        pshp[0] = num_keys * self.num_models
    if pshp[1] is None:
        pshp[1] = mem_shape[1]

    # Re-gather keys to avoid mixing between different keys: slice i takes
    # rows i, i + num_keys, i + 2 * num_keys, ... (stride num_keys along
    # axis 0, all columns), and the slices are stacked along axis 0.
    # NOTE(review): ``tf.concat(0, values)`` is the pre-1.0 TensorFlow
    # argument order (axis first); TF >= 1.0 expects ``tf.concat(values, 0)``.
    permed_keys = tf.concat(0, [tf.strided_slice(perms, [i, 0], pshp,
                                                 [num_keys, 1])
                                for i in range(num_keys)])
    logger.debug('memories = %s | dec_perms = %s',
                 mem_shape, permed_keys.get_shape().as_list())

    return self.conv_func(memories, permed_keys,
                          mem_shape[0],
                          self.num_models,
                          num_keys=num_keys * self.num_models,
                          conj=True)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号