def save_images(self, X, imgfile, density=False):
    """Scatter-plot the first two columns of ``X`` and save the figure.

    Parameters
    ----------
    X : array-like
        Indexed as ``X[:, 0]`` (x) and ``X[:, 1]`` (y); presumably a 2-D
        array of shape (n_points, >=2) — TODO confirm against callers.
    imgfile : str or path-like
        Destination passed to ``Figure.savefig``.
    density : bool, optional
        If True, colour each point by a Gaussian KDE estimate evaluated at
        that point. Otherwise, colour points by their row index using the
        ``coolwarm`` colormap.

    Notes
    -----
    If ``self.collection`` is not None, it is added to the plot in data
    coordinates. The first point is labelled ``'start'``. The axis limits
    are fixed to [-0.2, 1.2] on both axes. A dedicated figure is created
    for the plot and is always closed, even if plotting or saving raises.
    """
    # Use a fresh figure rather than the "current" one, so an unrelated
    # open figure is neither drawn into nor closed by this method.
    fig, ax = plt.subplots()
    try:
        x = X[:, 0]
        y = X[:, 1]
        if density:
            xy = np.vstack([x, y])
            z = scipy.stats.gaussian_kde(xy)(xy)
            # 'none' (not '') disables marker edges; recent Matplotlib
            # rejects the empty string as an invalid colour.
            ax.scatter(x, y, c=z, marker='o', edgecolor='none')
        else:
            # Colour by row index so the ordering of the points is visible.
            ax.scatter(x, y, marker='o', c=range(x.shape[0]),
                       cmap=plt.cm.coolwarm)
        if self.collection is not None:
            self.collection.set_transform(ax.transData)
            # NOTE(review): a Matplotlib artist can belong to only one Axes;
            # re-adding the same collection on a later call may raise —
            # confirm whether callers invoke this more than once.
            ax.add_collection(self.collection)
        if x.shape[0] > 0:
            # x[0], y[0] are data values, so keep the default data transform
            # (transAxes would treat them as 0..1 axes fractions).
            ax.text(x[0], y[0], 'start')
        ax.axis([-0.2, 1.2, -0.2, 1.2])
        fig.savefig(imgfile)
    finally:
        plt.close(fig)
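# 评论列表 ("comment list") / 文章目录 ("table of contents"):
# web-page navigation text captured along with this snippet; not code.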