def plot_surface_3d(X, y_actual, NN):
    """Plot the network's predicted surface with the training samples marked.

    The training samples are drawn as green markers. The surface is obtained
    by calling ``NN.output`` once for every node of a grid built from the
    sorted first and second coordinates of the samples, so the number of
    calls grows with the square of the number of samples. The function
    finishes by calling ``plt.show()``.

    Args:
        X: 2-D array of samples, indexed as ``X[:, 0]`` and ``X[:, 1]``;
            one row per sample.
        y_actual: values plotted on the z-axis for the rows of ``X``.
        NN: object exposing ``output(point)``; it is called with a length-2
            NumPy array. The result is stored in a single cell of a float
            array (assumed to be a scalar or size-1 value -- confirm against
            the network implementation).
    """
    fig = plt.figure()
    # Create the 3-D axes through the figure: a bare ``Axes3D(fig)`` is no
    # longer attached to the figure automatically on recent Matplotlib
    # releases, which leaves the plot empty.
    ax = fig.add_subplot(111, projection="3d")
    # Title the 3-D axes itself; calling ``plt.title`` before any axes exist
    # would implicitly create an extra 2-D axes behind the plot.
    ax.set_title("Predicted function with marked training samples")
    ax.view_init(elev=30, azim=70)

    ax.scatter(X[:, 0], X[:, 1], y_actual, c='g', depthshade=False)

    # Evaluation grid spanned by the sorted sample coordinates.
    grid_x0, grid_x1 = np.meshgrid(np.sort(X[:, 0]), np.sort(X[:, 1]))
    predicted_surface = np.zeros(grid_x0.shape)
    for i, j in np.ndindex(*predicted_surface.shape):
        point = np.array([grid_x0[i, j], grid_x1[i, j]])
        predicted_surface[i, j] = NN.output(point)

    ax.plot_surface(
        grid_x0,
        grid_x1,
        predicted_surface,
        rstride=2,
        cstride=2,
        linewidth=0,
        cmap=cm.coolwarm,
        alpha=0.5,
    )
    ax.grid(True)
    plt.show()
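# Comment list (web page navigation text, not part of the code)
# Table of contents (web page navigation text, not part of the code)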