network_accuracy.py source code

python

Project: TICC  Author: davidhallac
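import numpy as np


def computeF1Score_delete(num_cluster, matching_algo, actual_clusters, threshold_algo, save_matrix=False):
    """
    Compute the F1 score of each true cluster against its matched estimate.

    matching_algo[cluster] is the index of the estimated cluster matched to
    true cluster `cluster`. threshold_algo holds thresholded (0/1) estimated
    matrices; in actual_clusters any nonzero entry counts as an edge.
    Returns a NumPy array of num_cluster F1 scores.
    """
    F1_score = np.zeros(num_cluster)
    for cluster in range(num_cluster):
        matched_cluster = matching_algo[cluster]
        true_matrix = actual_clusters[cluster]
        estimated_matrix = threshold_algo[matched_cluster]
        if save_matrix:
            np.savetxt("estimated_matrix_cluster=" + str(cluster) + ".csv",
                       estimated_matrix, delimiter=",", fmt="%1.4f")
        TP = TN = FP = FN = 0.0
        # The snippet looped over num_stacked*n, names not defined here;
        # iterating over the matrix's own shape avoids that dependency.
        rows, cols = estimated_matrix.shape
        for i in range(rows):
            for j in range(cols):
                if estimated_matrix[i, j] == 1 and true_matrix[i, j] != 0:
                    TP += 1.0  # edge predicted and present
                elif estimated_matrix[i, j] == 0 and true_matrix[i, j] == 0:
                    TN += 1.0  # edge absent in both
                elif estimated_matrix[i, j] == 1 and true_matrix[i, j] == 0:
                    FP += 1.0  # edge predicted but absent
                else:
                    FN += 1.0  # remaining case: edge present but not predicted
        # Guard against division by zero when no edges are predicted or present.
        precision = TP / (TP + FP) if TP + FP > 0 else 0.0
        recall = TP / (TP + FN) if TP + FN > 0 else 0.0
        if precision + recall > 0:
            F1_score[cluster] = 2 * precision * recall / (precision + recall)
        else:
            F1_score[cluster] = 0.0
    return F1_score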
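The per-entry counting in computeF1Score_delete can also be done without the double loop by using NumPy boolean masks. The sketch below is not part of the TICC source: `f1_from_patterns` and `f1_scores` are hypothetical names, and it assumes each estimated matrix has the same shape as its true counterpart. It applies the same rules as the loop, so any entry that is not a true positive, false positive, or true negative is counted as a false negative.

```python
import numpy as np


def f1_from_patterns(estimated, true):
    """F1 of a thresholded (0/1) estimate against a true matrix (hypothetical helper)."""
    estimated = np.asarray(estimated)
    true_edge = np.asarray(true) != 0      # any nonzero true entry is an edge
    pred_edge = estimated == 1             # estimate marks an edge with 1
    tp = np.sum(pred_edge & true_edge)
    fp = np.sum(pred_edge & ~true_edge)
    tn = np.sum((estimated == 0) & ~true_edge)
    fn = estimated.size - tp - fp - tn     # everything else, as in the loop's else branch
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    if precision + recall == 0:
        return 0.0
    return float(2 * precision * recall / (precision + recall))


def f1_scores(num_cluster, matching_algo, actual_clusters, threshold_algo):
    """Vectorized counterpart of computeF1Score_delete (hypothetical, no file output)."""
    return np.array([
        f1_from_patterns(threshold_algo[matching_algo[k]], actual_clusters[k])
        for k in range(num_cluster)
    ])


true = np.array([[1.0, 0.5], [0.0, 2.0]])
est = np.array([[1, 1], [0, 0]])
# TP = 2, FP = 0, FN = 1, so precision = 1, recall = 2/3, F1 = 0.8
print(f1_from_patterns(est, true))
```

Using masks keeps the counting rules explicit while letting NumPy do the iteration, which matters when `num_stacked*n` is large.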