memory.py 文件源码

python
阅读 40 收藏 0 点赞 0 评论 0

项目:tefla 作者: openAGI 项目源码 文件源码
def get_hash_slots(self, query):
        """Gets hashed-to buckets for batch of queries.

        Args:
          query: 2-d Tensor of query vectors.

        Returns:
          A list of hashed-to buckets for each hash function.
        """

        binary_hash = [
            tf.less(tf.matmul(query, self.hash_vecs[i], transpose_b=True), 0)
            for i in xrange(self.num_libraries)]
        hash_slot_idxs = [
            tf.reduce_sum(
                tf.to_int32(binary_hash[i]) *
                tf.constant([[2 ** i for i in xrange(self.num_hashes)]],
                            dtype=tf.int32), 1)
            for i in xrange(self.num_libraries)]
        return hash_slot_idxs
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号