models_actinf.py 文件源码

python
阅读 20 收藏 0 点赞 0 评论 0

项目:smp_base 作者: x75 项目源码 文件源码
def plot_predictions_over_data(X, Y, mdl, saveplot = False, ax = None, datalim = 1000):
    """Plot a model's predictions for ``X`` on top of the targets ``Y``.

    Creates a new figure with one subplot per output column ``i`` of ``Y``.
    Each subplot shows the target ``Y[:, i]`` (black) and the prediction
    ``mdl.predict(X)[:, i]`` (red). Both are plotted against the input column
    ``X[:, i]``, so a sample's target and prediction share an x position. The
    subplot is annotated with the mean squared error and mean absolute error of
    the predictions for that column.

    Args:
        X: 2-D input array indexed as ``X[:, i]``. NOTE(review): output column
            ``i`` is plotted against input column ``i``, so ``X`` needs at
            least as many columns as ``Y`` (otherwise IndexError).
        Y: 2-D target array; its number of columns sets the number of
            subplots.
        mdl: object with a ``predict(X)`` method. Its return value must be
            indexable as ``[:, i]`` for every column of ``Y``.
        saveplot: if True, also write the figure to
            ``plot_predictions_over_data_<model class name>.jpg`` via
            ``savefig``.
        ax: unused; a new figure is always created. Kept for interface
            compatibility.
        datalim: unused. Kept for interface compatibility.
    """
    # Render point densities with hexbin instead of scatter dots (off by default).
    do_hexbin = False

    # Only plot the most recent 4000 samples to keep the figure responsive.
    if X.shape[0] > 4000:
        X = X[-4000:]
        Y = Y[-4000:]

    odim = Y.shape[1]
    # Number of predict() calls; values > 1 are only informative for models
    # whose predictions are stochastic.
    numsamples = 1
    Y_samples = [mdl.predict(X) for _ in range(numsamples)]

    fig = pl.figure()
    fig.suptitle("Predictions over data xy (numsamples = %d, %s)" % (numsamples, mdl.__class__.__name__))
    gs = gridspec.GridSpec(odim, 1)

    for i in range(odim):
        # Shadows the (unused) ``ax`` argument.
        ax = fig.add_subplot(gs[i])
        x = X[:, i]
        target = Y[:, i]

        # hexbin needs 1-D x and y of equal length, hence the column slices.
        if do_hexbin:
            ax.hexbin(x, target, gridsize = 20, alpha=1.0, cmap=pl.get_cmap("gray"))
        else:
            ax.plot(x, target, "k.", label="Y_", alpha=0.5)

        predictions = []
        for Y_sample in Y_samples:
            prediction = Y_sample[:, i]
            predictions.append(prediction)
            if do_hexbin:
                ax.hexbin(x, prediction, gridsize = 30, alpha=0.6, cmap=pl.get_cmap("Reds"))
            else:
                ax.plot(x, prediction, "r.", label="Y_", alpha=0.25)

        # Error statistics over every prediction sample. np.stack gives shape
        # (numsamples, len(target)); ``target`` broadcasts over the sample axis.
        error = target - np.stack(predictions)
        mse = np.mean(np.square(error))
        mae = np.mean(np.abs(error))

        # Place the statistics at fixed fractions of the current axis ranges.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        xran = xlim[1] - xlim[0]
        yran = ylim[1] - ylim[0]
        ax.text(xlim[0] + xran * 0.1, ylim[0] + yran * 0.3, "mse = %f" % mse)
        ax.text(xlim[0] + xran * 0.1, ylim[0] + yran * 0.5, "mae = %f" % mae)

    if saveplot:
        filename = "plot_predictions_over_data_%s.jpg" % (mdl.__class__.__name__,)
        savefig(fig, filename)

    fig.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号