def plot3D(data, output_labels_3d, centroids):
    """Create a 3-D scatter plot of clustered data and its centroids.

    Each point ``data[i]`` is drawn with a colour and marker size chosen by
    its cluster label ``output_labels_3d[i]``. The centroids are overlaid as
    large red crosses, and the figure is displayed with ``plt.show()``.

    Args:
        data: Array-like of points. Columns 0, 1 and 2 are used as the x, y
            and z coordinates. Only the first ``len(output_labels_3d)`` rows
            are plotted.
        output_labels_3d: Sequence of cluster labels, one per plotted row of
            ``data``. Labels 0-11 have a style. Any other label is skipped
            (not drawn).
        centroids: 2-D array of cluster centres. Columns 0, 1 and 2 are used
            as the x, y and z coordinates.

    Returns:
        None.
    """
    import numpy as np

    # label -> (marker size, colour). Labels 7-11 reuse colours from 0-6 and
    # are only distinguished by the slightly larger marker size.
    styles = {
        0: (20, 'k'),
        1: (20, 'r'),
        2: (20, 'b'),
        3: (20, 'c'),
        4: (20, 'g'),
        5: (20, 'y'),
        6: (20, 'm'),
        7: (25, 'y'),
        8: (25, 'b'),
        9: (25, 'k'),
        10: (25, 'm'),
        11: (25, 'g'),
    }

    labels = np.asarray(output_labels_3d)
    # Only rows 0..len(labels)-1 are plotted. Truncating keeps the boolean
    # masks below aligned with the rows of `points`.
    points = np.asarray(data)[:len(labels)]

    fig = plt.figure(3)
    # Axes3D(fig) no longer adds itself to the figure (deprecated in
    # Matplotlib 3.4, removed in 3.6), which leaves an empty plot. Ask the
    # figure for a 3-D axes instead.
    ax = fig.add_subplot(111, projection='3d')

    # One scatter call per label rather than per point. This creates far
    # fewer artists and is much faster on large datasets.
    for label, (size, colour) in styles.items():
        mask = labels == label
        if mask.any():
            ax.scatter(points[mask, 0], points[mask, 1], points[mask, 2],
                       s=size, c=colour)

    ax.scatter(centroids[:, 0], centroids[:, 1], centroids[:, 2],
               s=150, c='r', marker='x', linewidth=5)
    plt.show()
# Scraped web-page residue (headings "comment list" and "table of contents");
# not part of the program.