def plot_2d(self, f, a, b, grid_size=200):
    """Plot a 2-D function, the model mean +/- std, and the sampled points in 3-D.

    Draws four surfaces over the rectangle ``[a[0], b[0]] x [a[1], b[1]]``:
    ``f`` (black), ``mu`` (red), and ``mu + sigma`` / ``mu - sigma`` (blue).
    ``mu`` and ``sigma`` are the pair returned by
    ``self.utility.mean_and_std``. Each point in ``self.points`` is overlaid
    as a red marker at height ``f(point)``. The figure is displayed with
    ``plt.show()``; nothing is returned.

    Args:
        f: Vectorized function. It is called with an array of shape
            ``(2, N)``, where row 0 holds x coordinates and row 1 holds y
            coordinates. For the grid call its result must be reshapeable to
            ``(grid_size, grid_size)``.
        a: Lower corner of the plotting rectangle (``a[0]`` = x min,
            ``a[1]`` = y min).
        b: Upper corner of the plotting rectangle (``b[0]`` = x max,
            ``b[1]`` = y max).
        grid_size: Number of samples along each axis of the grid.
    """
    axis_x = np.linspace(a[0], b[0], num=grid_size)
    axis_y = np.linspace(a[1], b[1], num=grid_size)
    x, y = np.meshgrid(axis_x, axis_y)

    # f expects coordinates as rows: shape (2, grid_size**2).
    grid = np.stack([x.ravel(), y.ravel()])
    z = f(grid).reshape(x.shape)

    # mean_and_std is given one point per row: shape (grid_size**2, 2).
    mu, sigma = self.utility.mean_and_std(np.swapaxes(grid, 0, 1))
    mu = np.reshape(mu, x.shape)
    sigma = np.reshape(sigma, x.shape)

    points = np.asarray(self.points)
    # With no sampled points np.asarray gives a 1-D (or 0-d) array, and
    # indexing points[:, 0] would raise IndexError, so the scatter is skipped.
    has_points = points.ndim == 2 and points.shape[0] > 0
    if has_points:
        xs = points[:, 0]
        ys = points[:, 1]
        zs = f(np.swapaxes(points, 0, 1))

    fig = plt.figure()
    # Axes3D(fig) stopped attaching itself to the figure in Matplotlib 3.6,
    # which left this plot blank; add_subplot(projection="3d") works on both
    # old and new versions.
    ax = fig.add_subplot(111, projection="3d")
    ax.plot_surface(x, y, z, color="black", label="f", alpha=0.7,
                    linewidth=0, antialiased=False)
    ax.plot_surface(x, y, mu, color="red", label="mu", alpha=0.5)
    ax.plot_surface(x, y, mu + sigma, color="blue", label="mu+sigma", alpha=0.3)
    ax.plot_surface(x, y, mu - sigma, color="blue", alpha=0.3)
    if has_points:
        ax.scatter(xs, ys, zs, color="red", marker="o", s=100)
    # NOTE(review): no legend is drawn even though the surfaces carry labels;
    # presumably because legend() on plot_surface collections fails on some
    # Matplotlib versions. Confirm before enabling it.
    plt.show()
评论列表
文章目录