def plot_3d(X, y_actual, y_predicted=None):
    """Show a 3-D scatter plot of function values over 2-D inputs.

    The first two columns of ``X`` give the horizontal coordinates and the
    function values give the height. Actual values are drawn in green. When
    ``y_predicted`` is supplied, predicted values are drawn in blue on the
    same axes and a legend distinguishes the two series.

    Args:
        X: 2-D array-like indexed NumPy-style as ``X[:, 0]`` / ``X[:, 1]``;
            only the first two columns are used.
        y_actual: Actual function values, same length as ``X``.
        y_predicted: Optional predicted function values, same length as
            ``X``. If ``None``, only the actual samples are plotted.

    Displays the figure with ``plt.show()``; returns nothing.
    """
    fig = plt.figure()
    # Create the 3-D axes through the figure. In recent Matplotlib releases
    # ``Axes3D(fig)`` no longer attaches itself to the figure, which leaves
    # the rendered plot empty.
    ax = fig.add_subplot(111, projection="3d")
    ax.view_init(elev=30, azim=70)

    # Set the title on the 3-D axes itself. Calling plt.title() before the
    # 3-D axes existed created an extra 2-D axes underneath it.
    if y_predicted is None:
        ax.set_title("Approximated function samples")
    else:
        ax.set_title("Predicted vs actual function values")

    scatter_actual = ax.scatter(X[:, 0], X[:, 1], y_actual, c='g', depthshade=False)

    if y_predicted is not None:
        scatter_predicted = ax.scatter(
            X[:, 0], X[:, 1], y_predicted, c='b', depthshade=False
        )
        # A legend is only meaningful (and only possible) when both series exist.
        ax.legend(
            (scatter_actual, scatter_predicted),
            ('Actual values', 'Predicted values'),
            scatterpoints=1,
        )

    ax.grid(True)
    plt.show()
评论列表
文章目录