knnvisualization.py source code


Project: eqnet  Author: mast-group
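```python
import numpy as np
from scipy.spatial.distance import cdist

# Excerpt: this appears to be a nested helper. encoder, encodings, expression_data,
# eq_class_counts, eq_class_idx_to_names and tree_copy_with_start are not defined
# here; they are assumed to come from the surrounding code in knnvisualization.py.
def get_knn_score_for(tree, k=5):
    """Return the fraction of the k nearest non-identical neighbours of `tree`
    (cosine distance between encodings) that share tree's equivalence class."""
    tree = tree_copy_with_start(tree)
    tree_encoding = encoder.get_encoding([None, tree])  # This makes sure that token-based things fail
    tree_str_rep = str(tree)

    # Cosine distance from the query encoding to every stored encoding, nearest first.
    distances = cdist(np.atleast_2d(tree_encoding), encodings, 'cosine')
    knns = np.argsort(distances)[0]

    num_non_identical_nns = 0
    sum_equiv_nns = 0
    current_i = 0
    # Stop after k neighbours, at the end of the data, or once every other member
    # of tree's equivalence class could already have been counted.
    while (num_non_identical_nns < k and current_i < len(knns)
           and eq_class_counts[tree.symbol] - 1 > num_non_identical_nns):
        expression_idx = knns[current_i]
        current_i += 1
        neighbour = expression_data[expression_idx]
        same_class = eq_class_idx_to_names[neighbour['eq_class']] == tree.symbol
        if same_class and str(neighbour['tree']) == tree_str_rep:
            continue  # This is an identical expression, move on
        num_non_identical_nns += 1
        if same_class:
            sum_equiv_nns += 1
    # float() keeps the ratio fractional under Python 2 integer division as well.
    return "(%s-nn-stat: %s)" % (k, sum_equiv_nns / float(k))
```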
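The statistic above can be tried out without the eqnet encoder. The following is a minimal, self-contained sketch under stated assumptions: the function names `cosine_distances` and `knn_equivalence_score`, the `labels`/`keys` lists (standing in for `eq_class_idx_to_names[...]` and `str(tree)`), and the toy data are all hypothetical. Cosine distance is computed directly with NumPy rather than through SciPy's `cdist(..., 'cosine')`, which yields the same value, 1 minus the cosine similarity. As in the original, the query itself is assumed to be part of the data set, and the denominator stays `k` even when the loop stops early because the class has fewer than `k` other members.

```python
import numpy as np


def cosine_distances(query, points):
    """1 - cosine similarity between one query vector and each row of `points`."""
    query = np.asarray(query, dtype=float)
    points = np.asarray(points, dtype=float)
    sims = points @ query / (np.linalg.norm(points, axis=1) * np.linalg.norm(query))
    return 1.0 - sims


def knn_equivalence_score(query_vec, query_label, query_key, encodings, labels, keys, k=5):
    """Hypothetical re-implementation: share of the k nearest non-identical
    neighbours whose equivalence-class label equals `query_label`.

    labels[i] is the class of row i; keys[i] is a string used to detect copies of the query.
    """
    order = np.argsort(cosine_distances(query_vec, encodings))  # nearest first
    # Other members of the query's class (the query itself is assumed to be in the data).
    others_in_class = sum(1 for lab in labels if lab == query_label) - 1
    seen = 0
    same_class = 0
    for idx in order:
        if seen >= k or seen >= others_in_class:
            break
        if labels[idx] == query_label and keys[idx] == query_key:
            continue  # identical expression: skip it, as the original does
        seen += 1
        if labels[idx] == query_label:
            same_class += 1
    return same_class / float(k)


# Toy data: two classes of 2-D "encodings".
enc = np.array([[1.0, 0.0], [0.9, 0.1], [0.8, 0.3], [0.0, 1.0], [0.1, 0.9]])
labels = ["A", "A", "A", "B", "B"]
keys = ["a+b", "b+a", "a+b+0", "c*d", "d*c"]
print(knn_equivalence_score(enc[0], "A", "a+b", enc, labels, keys, k=2))  # 1.0
```

With `k=3` the same query scores 2/3: class A has only two other members, so the search stops after two neighbours while the ratio is still taken over `k`.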