def plot_predictions_over_data(X, Y, mdl, saveplot = False, ax = None, datalim = 1000):
    """Scatter-plot a model's predictions over the data it should predict.

    One subplot is created per output dimension of ``Y``. In subplot ``i``
    the targets ``Y[:, i]`` are drawn as black dots and the model's
    predictions ``mdl.predict(X)[:, i]`` as red dots, both against the same
    input column. The mean squared and mean absolute error of that output
    dimension are written into the axes.

    Args:
        X: 2-D input array (indexed as ``X[:, k]``). If it has more than
            4000 rows, only the last 4000 rows are used.
        Y: 2-D target array (indexed as ``Y[:, i]``), truncated like ``X``.
        mdl: object with a ``predict(X)`` method whose result is indexable
            as ``[:, i]`` for every output dimension ``i``.
        saveplot: if True, save the figure via ``savefig`` as
            ``plot_predictions_over_data_<ModelClassName>.jpg``.
        ax: ignored; a new figure is always created. Kept for backward
            compatibility.
        datalim: ignored; the row limit is fixed at 4000. Kept for backward
            compatibility.

    Raises:
        IndexError: if ``X`` has more than one column but fewer columns than
            ``Y`` (output ``i`` is paired with input column ``i``).
    """
    maxpoints = 4000
    # hexbin rendering is currently disabled in every case; flip the flag
    # below to use it for large datasets.
    do_hexbin = False
    if X.shape[0] > maxpoints:
        do_hexbin = False  # True
        # keep only the last `maxpoints` rows to bound plotting cost
        X = X[-maxpoints:]
        Y = Y[-maxpoints:]

    idim = X.shape[1]
    odim = Y.shape[1]
    # number of predict() calls; more than one presumably only makes sense
    # if mdl.predict is stochastic -- NOTE(review): confirm
    numsamples = 1
    Y_samples = [mdl.predict(X) for _ in range(numsamples)]

    fig = pl.figure()
    fig.suptitle("Predictions over data xy (numsamples = %d, %s)" % (numsamples, mdl.__class__.__name__))
    gs = gridspec.GridSpec(odim, 1)

    for i in range(odim):
        ax = fig.add_subplot(gs[i])
        target = Y[:, i]
        # Targets and predictions share one x-axis. With a single input
        # column every output is plotted over it; otherwise output i is
        # paired with input column i.
        x_col = X[:, 0] if idim == 1 else X[:, i]

        if do_hexbin:
            # hexbin needs 1-D x and y of equal length
            ax.hexbin(x_col, target, gridsize=20, alpha=1.0, cmap=pl.get_cmap("gray"))
        else:
            ax.plot(x_col, target, "k.", label="Y_", alpha=0.5)

        for Y_pred in Y_samples:
            prediction = Y_pred[:, i]
            if do_hexbin:
                ax.hexbin(x_col, prediction, gridsize=30, alpha=0.6, cmap=pl.get_cmap("Reds"))
            else:
                ax.plot(x_col, prediction, "r.", label="Y_", alpha=0.25)

        # Annotate once per subplot with the error of the last prediction
        # sample; text is placed relative to the current axis limits.
        xlim = ax.get_xlim()
        ylim = ax.get_ylim()
        error = target - prediction
        mse = np.mean(np.square(error))
        mae = np.mean(np.abs(error))
        xran = xlim[1] - xlim[0]
        yran = ylim[1] - ylim[0]
        ax.text(xlim[0] + xran * 0.1, ylim[0] + yran * 0.3, "mse = %f" % mse)
        ax.text(xlim[0] + xran * 0.1, ylim[0] + yran * 0.5, "mae = %f" % mae)

    if saveplot:
        filename = "plot_predictions_over_data_%s.jpg" % (mdl.__class__.__name__,)
        savefig(fig, filename)

    fig.show()
评论列表
文章目录