def plot_x_y_yhat(x, y, y_hat, xsz, ysz, binz=False, threshold=0.5):
    """Plot x, y and y_hat side by side.

    Closes every open matplotlib figure first. It then draws the three images
    in greyscale (``cm.Greys_r``) as a single row of three panels on a
    15 x 10.8 inch, 300 dpi figure.

    Args:
        x: Image for the left panel, passed to ``imshow`` (presumably the
            model input -- TODO confirm).
        y: Image for the middle panel (presumably the target -- TODO confirm).
        y_hat: Image for the right panel (presumably the prediction --
            TODO confirm).
        xsz: Value shown in the first title as ``"x:<xsz>x<xsz>"``
            (presumably the side length of ``x``).
        ysz: Value shown in the second and third titles as
            ``"<ysz>x<ysz>"`` (presumably the side length of ``y``).
        binz: If True, binarize ``y_hat`` before plotting. Entries greater
            than ``threshold`` become 1.0 and all others become 0.0.
        threshold: Cut-off used when ``binz`` is True. It defaults to 0.5,
            the value that was previously hard-coded.

    Returns:
        The matplotlib Figure containing the three panels.
    """
    plt.close("all")
    fig = plt.figure(figsize=(15, 10.8), dpi=300)
    grid = gridspec.GridSpec(1, 3)
    if binz:
        # The comparison yields booleans; multiplying by 1. turns them into
        # floats. Assumes y_hat supports elementwise ``>`` (e.g. a NumPy
        # array) -- TODO confirm.
        y_hat = (y_hat > threshold) * 1.
    x_dims = "{0}x{0}".format(str(xsz))
    y_dims = "{0}x{0}".format(str(ysz))
    # Pair each image with its title so they cannot drift out of sync.
    panels = [
        (x, "x:" + x_dims),
        (y, "y:" + y_dims),
        (y_hat, "yhat:" + y_dims),
    ]
    for n, (image, title) in enumerate(panels):
        # Draw on the explicit Axes instead of pyplot's implicit current one.
        # All panels share the same colormap, so no per-panel branch is needed.
        ax = fig.add_subplot(grid[n])
        ax.imshow(image, cmap=cm.Greys_r)
        ax.set_title(title)
    return fig
tools.py 文件源码
python
阅读 26
收藏 0
点赞 0
评论 0
评论列表
文章目录