MMTransE.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:MTransE 作者: muhaochen 项目源码 文件源码
def kNN_entity(self, vec, tgt_lan='en', topk=10, method=0, self_vec_id=None, replace_q=True):
        """Return the entities of one language model closest to ``vec``.

        Args:
            vec: Query embedding; compared against every row of
                ``model.vec_e``.
            tgt_lan: Key into ``self.models`` selecting the model to search.
            topk: Maximum number of neighbours to return. A value <= 0
                yields an empty list.
            method: ``1`` uses cosine distance (``SP.distance.cosine``);
                any other value uses the Euclidean norm of the difference.
            self_vec_id: Optional row index of ``model.vec_e`` to skip
                (typically the query entity itself).
            replace_q: If true, keep a bounded heap of at most ``topk``
                candidates while scanning; otherwise push every candidate
                and trim the heap afterwards.

        Returns:
            A list of ``(entity_name, distance)`` tuples ordered from the
            nearest to the farthest, or ``None`` when no model is registered
            for ``tgt_lan``.
        """
        model = self.models.get(tgt_lan)
        if model is None:
            # Single pre-formatted argument: prints the same text under both
            # the Python 2 print statement and the Python 3 print function.
            print("Model for language %s  does not exist." % (tgt_lan,))
            return None
        if topk <= 0:
            # Nothing requested; also avoids peeking at an empty heap below.
            return []

        # NOTE(review): this relies on model.index_dist ordering objects so
        # that the *largest* distance sits at the heap root (the original
        # comment said the head "fetches the biggest") -- confirm against
        # the index_dist definition.
        q = []
        for i, candidate in enumerate(model.vec_e):
            # skip self
            if self_vec_id is not None and i == self_vec_id:
                continue
            if method == 1:
                dist = SP.distance.cosine(vec, candidate)
            else:
                dist = LA.norm(vec - candidate)
            if (not replace_q) or len(q) < topk:
                HP.heappush(q, model.index_dist(i, dist))
            elif q[0].dist > dist:
                # q[0] is the heap root, i.e. the worst candidate kept so
                # far; swap it out when the new one is closer.
                HP.heapreplace(q, model.index_dist(i, dist))

        # Only the unbounded (replace_q=False) heap can exceed topk; popping
        # discards the root first, which is the farthest candidate.
        while len(q) > topk:
            HP.heappop(q)

        # Pops come out farthest-first, so collect and reverse once to get
        # nearest-first order.
        rst = []
        while q:
            item = HP.heappop(q)
            rst.append((model.vocab_e[model.vec2e[item.index]], item.dist))
        rst.reverse()
        return rst

    #given entity name, find kNN
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号