def get_knn_score_for(tree, k=5):
    """Format the share of same-class expressions among the nearest neighbours of *tree*.

    The tree is passed through ``tree_copy_with_start``, encoded with
    ``encoder.get_encoding`` and compared (cosine distance) against every row
    of ``encodings``. Candidates are visited from nearest to farthest.
    A candidate whose equivalence-class name equals ``tree.symbol`` and whose
    ``str()`` form equals that of the query tree is treated as the same
    expression and skipped without being counted.

    Visiting stops as soon as one of these holds:
      * ``k`` non-identical neighbours have been counted,
      * the ranked candidate list is exhausted,
      * the number counted reaches ``eq_class_counts[tree.symbol] - 1``
        (presumably the number of other members of the class -- confirm
        against where ``eq_class_counts`` is built).

    :param tree: expression tree to score; must expose a ``symbol`` attribute.
    :param k: number of non-identical neighbours to consider.
    :return: the string ``"(<k>-nn-stat: <ratio>)"`` where ``ratio`` is the
        number of counted neighbours in the same class divided by ``k``.
        The denominator is always ``k``, even when fewer than ``k``
        neighbours were counted.
    """
    tree = tree_copy_with_start(tree)
    # This makes sure that token-based things fail
    query_encoding = encoder.get_encoding([None, tree])
    query_repr = str(tree)

    cosine_distances = cdist(np.atleast_2d(query_encoding), encodings, 'cosine')
    ranked_indices = np.argsort(cosine_distances)[0]

    counted = 0
    same_class_hits = 0
    for candidate_idx in ranked_indices:
        if counted >= k:
            break
        # Stop once the counted neighbours reach the class count minus one.
        peers_available = eq_class_counts[tree.symbol] - 1
        if peers_available <= counted:
            break

        candidate = expression_data[candidate_idx]
        in_same_class = eq_class_idx_to_names[candidate['eq_class']] == tree.symbol
        if in_same_class and str(candidate['tree']) == query_repr:
            # This is an identical expression, move on
            continue

        counted += 1
        if in_same_class:
            same_class_hits += 1

    return "(%s-nn-stat: %s)" % (k, same_class_hits / k)
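# NOTE: trailing web-page navigation labels ("comment list", "table of contents") -- scraping residue, not part of the code.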