plot.py source code

python

Project: simple-linear-regression  Author: williamd4112
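import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  (registers the '3d' projection on older Matplotlib)


def plot_3d(model, phi, x_min, x_max, y_min, y_max, z_min, z_max, filename=None):
    """Plot model(phi([x, y])) as a 3D surface over [x_min, x_max) x [y_min, y_max)."""
    fig = plt.figure()
    # fig.gca(projection='3d') is no longer supported in recent Matplotlib releases.
    ax = fig.add_subplot(projection='3d')

    # Sample the input plane on a regular grid with a step of 5 units.
    X = np.arange(x_min, x_max, 5)
    Y = np.arange(y_min, y_max, 5)
    X, Y = np.meshgrid(X, Y)

    # Flatten the grid into one (x, y) row per point. ravel() also handles
    # non-square grids, whereas reshaping to len(X)**2 only works when the
    # x and y ranges produce the same number of samples.
    x, y = X.ravel(), Y.ravel()
    Z = model(np.matrix(phi(np.array([x, y], dtype=np.float32).T)))

    # Restore the grid shape so Z lines up element-wise with X and Y.
    Z = np.asarray(Z).reshape(X.shape)

    # Plot the surface.
    surf = ax.plot_surface(X, Y, Z, cmap=cm.coolwarm,
                           linewidth=0, antialiased=False, shade=True)

    # Customize the z axis.
    ax.set_zlim(z_min, z_max)
    ax.zaxis.set_major_locator(LinearLocator(10))
    ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))

    # Add a color bar which maps values to colors.
    fig.colorbar(surf, shrink=0.5, aspect=5)

    # Save the figure when a filename is given, then display it.
    if filename is not None:
        fig.savefig(filename)
    plt.show()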
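The core of plot_3d is the grid step: build a mesh, flatten it into one (x, y) row per point, pass it through phi and the model, then reshape the predictions back to the mesh shape. The sketch below isolates that step using NumPy only. The basis function phi (a bias plus the raw coordinates), the linear model, its weights W, and the helper evaluate_on_grid are all hypothetical stand-ins, not code from the simple-linear-regression project. The x and y ranges are deliberately different lengths, which is the case where reshaping by len(X)**2 would fail.

```python
import numpy as np

# Hypothetical basis function (assumption): maps each (x, y) row to [1, x, y].
def phi(points):
    points = np.asarray(points, dtype=np.float64)
    ones = np.ones((points.shape[0], 1))
    return np.hstack([ones, points])

# Hypothetical linear model z = w0 + w1*x + w2*y with assumed weights.
W = np.array([1.0, 2.0, -0.5])

def model(features):
    return np.asarray(features) @ W

def evaluate_on_grid(model, phi, x_min, x_max, y_min, y_max, step=5):
    """Hypothetical helper: evaluate model(phi(.)) on a grid; X, Y, Z share one shape."""
    xs = np.arange(x_min, x_max, step)
    ys = np.arange(y_min, y_max, step)
    X, Y = np.meshgrid(xs, ys)                          # shape (len(ys), len(xs))
    points = np.column_stack([X.ravel(), Y.ravel()])    # one (x, y) row per grid point
    Z = np.asarray(model(phi(points))).reshape(X.shape)  # back to the grid shape
    return X, Y, Z

# Non-square grid: 4 x-samples (0, 5, 10, 15) and 2 y-samples (0, 5).
X, Y, Z = evaluate_on_grid(model, phi, 0, 20, 0, 10)
print(X.shape, Z.shape)  # (2, 4) (2, 4)
```

The returned X, Y and Z can be handed directly to ax.plot_surface, since all three have the same shape.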