def plot_3d(objective_function, length=20):
    """Plot a two-variable objective function as a 3D surface.

    The function is sampled on a ``length`` x ``length`` grid spanning its
    bounds. The samples are drawn as a translucent surface, and contour lines
    are projected onto the plane at the minimum sampled value. The figure is
    then displayed with ``pyplot.show()``.

    :param objective_function: object exposing ``get_bounds()``, which returns
        one ``(lower, upper)`` pair per dimension, and ``evaluate(points)``,
        which receives an ``(length * length, 2)`` array of points and returns
        one value per point.
    :type objective_function: object
    :param length: number of samples along each axis.
    :type length: int
    :return: ``None``. If the bounds are not two-dimensional, nothing is
        plotted and the function returns immediately.
    :rtype: None
    """
    bounds = objective_function.get_bounds()
    if len(bounds) != 2:
        # Only functions of exactly two variables can be drawn as a surface.
        return
    x_low, x_high = bounds[0][0], bounds[0][1]
    y_low, y_high = bounds[1][0], bounds[1][1]

    x_axis = np.linspace(x_low, x_high, length)
    y_axis = np.linspace(y_low, y_high, length)
    # Default 'xy' indexing: both grids have shape (length, length), with rows
    # following y and columns following x.
    x_grid, y_grid = np.meshgrid(x_axis, y_axis)
    # One row per sample point, shape (length * length, 2).
    grid = np.column_stack((x_grid.ravel(), y_grid.ravel()))
    # A row-major reshape restores the meshgrid layout of the flattened points.
    z_grid = np.asarray(objective_function.evaluate(grid)).reshape(length, length)

    fig = pyplot.figure()
    # Figure.gca(projection=...) was removed in Matplotlib 3.6; add_subplot is
    # the supported way to create 3D axes.
    axis = fig.add_subplot(111, projection='3d')
    surf = axis.plot_surface(x_grid, y_grid, z_grid, rstride=1, cstride=1,
                             cmap=cm.cool, linewidth=0, antialiased=False,
                             alpha=0.3)
    # Project contour lines onto the plane at the lowest sampled value. The
    # arrays are passed directly; converting to lists was unnecessary.
    axis.contour(x_grid, y_grid, z_grid, zdir='z', offset=z_grid.min(),
                 cmap=cm.cool)
    axis.set_xlim(x_low, x_high)
    axis.set_ylim(y_low, y_high)
    axis.set_title(objective_function.__class__.__name__)
    axis.zaxis.set_major_locator(LinearLocator(10))
    axis.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
    fig.colorbar(surf, shrink=0.5, aspect=5)
    pyplot.show()
# [web-page residue] Comment list
# [web-page residue] Table of contents