classifier.py source code

python
Views 26 · Favorites 0 · Likes 0 · Comments 0

Project: Clustering · Author: Ram81 · Project source · File source
def plot3D(data, output_labels_3d, centroids):
    '''
    Create a 3-D scatter plot of clustered points and their centroids.

    Each point is drawn with a marker size and colour chosen by its cluster
    label. Labels 0-11 are supported; points with any other label are not
    drawn. Centroids are overlaid as large red 'x' markers, and the figure
    is displayed with ``plt.show()``.

    Args:
        data: coordinates indexed as ``data[i, 0..2]`` -- presumably an
            (n, >=3) NumPy array (TODO confirm with callers). Must support
            fancy indexing with a list of row indices.
        output_labels_3d: sequence of cluster labels, one per row of
            ``data`` to be plotted.
        centroids: centroid coordinates indexed as ``centroids[:, 0..2]``.

    Returns:
        None.
    '''
    # (marker size, colour) per cluster label. Labels 7-11 reuse colours
    # from 0-6 with a slightly larger marker, exactly as the original
    # hard-coded if/elif chain did.
    styles = {
        0: (20, 'k'), 1: (20, 'r'), 2: (20, 'b'), 3: (20, 'c'),
        4: (20, 'g'), 5: (20, 'y'), 6: (20, 'm'),
        7: (25, 'y'), 8: (25, 'b'), 9: (25, 'k'), 10: (25, 'm'),
        11: (25, 'g'),
    }

    fig = plt.figure(3)
    # Axes3D(fig) stopped attaching itself to the figure in Matplotlib 3.6
    # (deprecated in 3.4), which leaves the plot empty; add_subplot with the
    # '3d' projection works on both older and current versions.
    ax = fig.add_subplot(111, projection='3d')

    # Group row indices by label so each cluster is drawn with one
    # vectorised scatter call instead of one call (and one artist) per point.
    groups = {}
    for i, label in enumerate(output_labels_3d):
        if label in styles:
            groups.setdefault(label, []).append(i)

    for label, rows in groups.items():
        size, colour = styles[label]
        ax.scatter(data[rows, 0], data[rows, 1], data[rows, 2],
                   s=size, c=colour)

    ax.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2],
               s=150, c='r', marker='x', linewidth=5)

    plt.show()
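Comments
Table of contents


Questions


Interview experiences


Articles

WeChat
Official account

Scan the QR code to follow the official account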