def plot(self):
    """Render the current state as a 3-D plot and save it as ``fig_NN.png``.

    Draws, over the grid ``self.meshgrid``:

    * a green wireframe of ``self.mu`` reshaped to ``self.meshgrid[0].shape``
      (presumably the model's current mean estimate -- confirm);
    * a blue wireframe of ``self.environment.sample(self.meshgrid)``
      (presumably the true surface; assumed to already have the grid's
      shape since it is not reshaped -- confirm);
    * a red scatter of the points in ``self.X`` (first two coordinates of
      each point) against the values in ``self.T``.

    The image is written relative to the current working directory as
    ``'fig_%02d.png' % len(self.X)``. Calls made with the same number of
    points therefore overwrite the same file.
    """
    fig = plt.figure()
    try:
        # Axes3D(fig) stopped attaching itself to the figure in matplotlib
        # 3.4+ (auto-add removed in 3.6), which produces a blank image;
        # add_subplot(..., projection='3d') works across versions.
        ax = fig.add_subplot(111, projection='3d')
        grid_x, grid_y = self.meshgrid[0], self.meshgrid[1]
        ax.plot_wireframe(grid_x, grid_y, self.mu.reshape(grid_x.shape),
                          alpha=0.5, color='g')
        ax.plot_wireframe(grid_x, grid_y,
                          self.environment.sample(self.meshgrid),
                          alpha=0.5, color='b')
        ax.scatter([x[0] for x in self.X], [x[1] for x in self.X], self.T,
                   c='r', marker='o', alpha=1.0)
        fig.savefig('fig_%02d.png' % len(self.X))
    finally:
        # pyplot keeps every figure alive until it is explicitly closed, so
        # without this each call leaks a figure (and matplotlib warns once
        # more than 20 are open).
        plt.close(fig)
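# NOTE(review): stray web-page navigation labels ("Comment list" /
# "Table of contents") appeared here; they are scraping residue, not code.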