glove_factory.py 文件源码

python
阅读 17 收藏 0 点赞 0 评论 0

项目:Msc_Multi_label_ZeroShot 作者: thomasSve 项目源码 文件源码
def build_tree(self, space, words, imdb_name, num_trees = 1000, vector_list = None):
        """
        Build, or load from an on-disk cache, an Annoy index over word vectors
        so that euclidean distances between words can be queried.

        Two cache files are used, both under ``self._devkit_path``:
            <name>_<imdb_name><space>.ann        -- the Annoy index itself
            <name>_<imdb_name><space>array.pkl   -- pickled list of labels
        The cache is used only when BOTH files exist. Otherwise the index is
        built from scratch and both files are (re)written.

        Args:
            space: 0 to index every (word, vector) pair in ``self._vectors``;
                any other value to index only the words in ``words``.
            words: iterable of words to index when ``space != 0``. Each
                word's vector is obtained via ``self.word_vector(word)``.
                Ignored when ``space == 0``.
            imdb_name: dataset name; used only to build the cache file names.
            num_trees: value passed to ``AnnoyIndex.build``.
            vector_list: currently unused; kept for backward compatibility.

        Side effects:
            Sets ``self._tree`` to the Annoy index and ``self._labels`` to a
            list where ``self._labels[i]`` is the word stored as Annoy item i.

        NOTE(review): the cache key is only (self.name, imdb_name, space). A
        cached index is reused even if ``words`` or ``num_trees`` differ from
        the call that built it.
        """
        tree_path = osp.join(self._devkit_path, self.name + '_' + imdb_name + str(space) + '.ann')
        pckl_path = osp.join(self._devkit_path, self.name + '_' + imdb_name + str(space) + 'array'+".pkl")
        t = AnnoyIndex(self._dimension, metric="euclidean")

        # Trust the cache only when both halves are present. An index without
        # its label list would otherwise crash on open() rather than being
        # rebuilt.
        if osp.exists(tree_path) and osp.exists(pckl_path):
            print("Tree exist, loading from file...")
            t.load(tree_path)
            with open(pckl_path, 'rb') as handle:
                # NOTE: pickle.load can execute arbitrary code; only load
                # cache files this code wrote itself.
                self._labels = pickle.load(handle)
            self._tree = t
            return

        print("Building tree...")

        # Pick the (word, vector) source. The generator keeps the original
        # lazy, in-order calls to self.word_vector for the word-list case.
        if space == 0:
            pairs = self._vectors.items()
        else:
            pairs = ((w, self.word_vector(w)) for w in words)

        # Annoy item ids are 0..n-1, and word_list[i] is the label of item i.
        word_list = []
        for index, (word, feature) in enumerate(pairs):
            word_list.append(word)
            t.add_item(index, feature)

        t.build(num_trees)
        self._tree = t
        self._labels = word_list

        # Persist both halves of the cache.
        t.save(tree_path)
        with open(pckl_path, 'wb') as handle:
            pickle.dump(word_list, handle)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号