def plot_x_x_yhat(x, x_hat):
    """Plot x and x_hat side by side in a one-row, two-panel figure.

    Every open pyplot figure is closed before the new one is created, so
    figures made earlier through pyplot are discarded.

    Args:
        x: Image shown in the left panel, titled
            ``"xin:<shape[0]>x<shape[1]>"``. Must expose a ``shape`` with at
            least two entries and be accepted by ``plt.imshow``
            (presumably a 2-D array -- confirm against callers).
        x_hat: Image shown in the right panel, titled
            ``"xout:<shape[0]>x<shape[1]>"``. Same requirements as ``x``.

    Returns:
        The newly created figure containing both panels, drawn with the
        ``Greys_r`` colormap and with their axes hidden.
    """
    plt.close("all")
    fig = plt.figure()
    grid = gridspec.GridSpec(1, 2)

    # Pair each image with its title so the two cannot drift out of sync.
    panels = [
        (x, "xin:{}x{}".format(x.shape[0], x.shape[1])),
        # The output label is built from x_hat's own dimensions; it previously
        # mixed in x.shape[1], which mislabels the panel whenever the two
        # images differ in size.
        (x_hat, "xout:{}x{}".format(x_hat.shape[0], x_hat.shape[1])),
    ]

    for cell, (image, title) in enumerate(panels):
        # add_subplot makes the new axes current, so plt.imshow draws into it.
        ax = fig.add_subplot(grid[cell])
        plt.imshow(image, cmap=cm.Greys_r)
        ax.set_title(title)
        ax.set_axis_off()
    return fig
# 评论列表 (comment list)
# 文章目录 (table of contents)