plotter.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:AIclass 作者: mttk 项目源码 文件源码
def plot_surface_3d(X, y_actual, NN):
    fig = plt.figure()
    plt.title("Predicted function with marked training samples")
    ax = Axes3D(fig)

    size = X.shape[0]

    ax.view_init(elev=30, azim=70)
    scatter_actual = ax.scatter(X[:,0], X[:,1], y_actual, c='g', depthshade=False)

    x0s = sorted(X[:,0])
    x1s = sorted(X[:,1])

    x0s, x1s = np.meshgrid(x0s, x1s)
    predicted_surface = np.zeros((size, size))

    for i in range(size):
        for j in range(size):
            predicted_surface[i,j] = NN.output(np.array([x0s[i,j], x1s[i,j]]))

    surf = ax.plot_surface(x0s, x1s, predicted_surface, rstride=2, cstride=2, linewidth=0, cmap=cm.coolwarm, alpha=0.5)

    plt.grid()
    plt.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号