models_actinf.py 文件源码

python
阅读 22 收藏 0 点赞 0 评论 0

项目:smp_base 作者: x75 项目源码 文件源码
def plot_nodes_over_data_scattermatrix_hexbin(X, Y, mdl, predictions, distances, activities, saveplot = False):
    """models_actinf.plot_nodes_over_data_scattermatrix_hexbin

    Plot model predictions over the hexbinned data, expanding the
    input/output dimension pairs as a scattermatrix.

    One subplot is created per (input dim j, output dim i) pair on an
    ``idim`` x ``odim`` grid (row = input dim, column = output dim). Each
    subplot shows a grayscale hexbin density of ``X[:, j]`` against
    ``Y[:, i]``, with ``predictions[:, i]`` overlaid as red dots.

    Args:
        X: 2-D input data; ``X.shape[1]`` is taken as the input dimensionality.
        Y: 2-D target data; ``Y.shape[1]`` is taken as the output dimensionality.
        mdl: the model; only its class name is used (figure title, filename).
        predictions: 2-D model predictions indexed as ``predictions[:, i]``
            per output dim ``i`` (assumed to have the same number of rows
            as ``X`` — TODO confirm against callers).
        distances: not used by this function; kept for interface compatibility.
        activities: not used by this function; kept for interface compatibility.
        saveplot: if True, also save the figure via ``savefig`` to
            ``plot_nodes_over_data_scattermatrix_hexbin_<ModelClass>.jpg``.
    """
    idim = X.shape[1]
    odim = Y.shape[1]
    mdl_name = mdl.__class__.__name__

    fig = pl.figure()
    fig.suptitle("Predictions over data xy scattermatrix/hexbin (%s)" % (mdl_name))
    gs = gridspec.GridSpec(idim, odim)
    # figaxes[j][i] holds the axes for input dim j (row) vs output dim i (column)
    figaxes = [[fig.add_subplot(gs[j, i]) for i in range(odim)] for j in range(idim)]

    # one prediction marker colour per output dimension
    colsb = ["r" for _ in range(odim)]
    for i in range(odim):
        for j in range(idim):
            ax = figaxes[j][i]
            ax.hexbin(X[:, j], Y[:, i], gridsize=20, alpha=0.75, cmap="gray")
            # The colour is keyed by output dim i (colsb has odim entries).
            # Indexing it with the input dim j raised IndexError when idim > odim.
            ax.plot(X[:, j], predictions[:, i], colsb[i] + "o", alpha=0.15, label="pred_%d" % i, markersize=8)

    if saveplot:
        filename = "plot_nodes_over_data_scattermatrix_hexbin_%s.jpg" % (mdl_name,)
        savefig(fig, filename)
    fig.show()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号