def plotModel3D(vectorFile, numClusters):
    """Cluster a Doc2Vec model's document vectors and show them in a 3-D scatter plot.

    Loads a saved Doc2Vec model from the ``Models`` directory, projects its
    document vectors to 10 dimensions with PCA, clusters the projected
    vectors with k-means, and draws PCA components 5, 2 and 3 coloured by
    cluster label. Adapted from the scikit-learn iris k-means example:
    http://scikit-learn.org/stable/auto_examples/cluster/plot_cluster_iris.html

    Args:
        vectorFile: File name of the saved Doc2Vec model inside ``Models``.
        numClusters: Number of k-means clusters to fit.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    import os  # local import: keeps the path handling portable across OSes

    model = Doc2Vec.load(os.path.join("Models", vectorFile))
    # gensim 4 renamed model.docvecs -> model.dv and doctag_syn0 -> vectors;
    # support both so the function works on old and new gensim releases.
    if hasattr(model, "dv"):
        doc_vecs = model.dv.vectors
    else:
        doc_vecs = model.docvecs.doctag_syn0

    reduced_data = PCA(n_components=10).fit_transform(doc_vecs)
    kmeans = KMeans(init="k-means++", n_clusters=numClusters, n_init=10)
    labels = kmeans.fit(reduced_data).labels_

    fig = plt.figure(1, figsize=(10, 10))
    fig.clf()  # figure 1 may already exist from an earlier call; start clean
    # add_axes(..., projection="3d") attaches the 3-D axes to the figure;
    # a bare Axes3D(fig, ...) is no longer auto-added in matplotlib >= 3.6.
    ax = fig.add_axes([0, 0, 0.95, 1], projection="3d", elev=48, azim=134)
    # Builtin float replaces np.float, which was removed in NumPy 1.24.
    ax.scatter(
        reduced_data[:, 5],
        reduced_data[:, 2],
        reduced_data[:, 3],
        c=labels.astype(float),
    )
    # Hide tick labels: the PCA component values carry no intrinsic meaning.
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.set_ticklabels([])
    plt.show()
# 评论列表 (Comment list)
# 文章目录 (Table of contents)