def plot_prediction(x_test, y_test, prediction, save=False):
    """Plot inputs, labels and normalised predictions side by side.

    One row of three panels is drawn per sample: the input ``x_test[i]``,
    index 1 along the last axis of ``y_test[i]``, and index 1 along the last
    axis of ``prediction[i]`` rescaled to the range [0, 1]. Every panel gets
    its own colorbar; only the first row is titled ("x", "y", "pred").

    :param x_test: array of input samples; the size of its first axis sets
        the number of rows in the figure
    :param y_test: array of labels, indexed as ``y_test[i, ..., 1]``
    :param prediction: array of predictions, indexed as
        ``prediction[i, ..., 1]``. It is not modified by this function.
    :param save: falsy to display the figure interactively; otherwise the
        value is passed to ``fig.savefig`` as the output target
    """
    import matplotlib.pyplot as plt

    test_size = x_test.shape[0]
    fig, ax = plt.subplots(test_size, 3, figsize=(12, 12), sharey=True, sharex=True)

    # NOTE(review): crop_to_shape is defined elsewhere; presumably it trims
    # the inputs to the spatial size of the prediction -- confirm.
    x_test = crop_to_shape(x_test, prediction.shape)
    y_test = crop_to_shape(y_test, prediction.shape)

    # plt.subplots returns a 1-D array of axes when there is a single row;
    # force 2-D so that ax[i, j] indexing works for every test_size.
    ax = np.atleast_2d(ax)
    for i in range(test_size):
        cax = ax[i, 0].imshow(x_test[i])
        plt.colorbar(cax, ax=ax[i, 0])

        cax = ax[i, 1].imshow(y_test[i, ..., 1])
        plt.colorbar(cax, ax=ax[i, 1])

        # Work on a float copy: the indexed slice is a view of `prediction`,
        # so normalising it in place would overwrite the caller's data (and
        # in-place division would fail outright for integer dtypes).
        pred = np.array(prediction[i, ..., 1], dtype=float)
        pred -= np.amin(pred)
        peak = np.amax(pred)
        # A constant slice has peak == 0 after the shift; skip the division
        # instead of filling the panel with NaNs.
        if peak > 0:
            pred /= peak
        cax = ax[i, 2].imshow(pred)
        plt.colorbar(cax, ax=ax[i, 2])

        if i == 0:
            ax[i, 0].set_title("x")
            ax[i, 1].set_title("y")
            ax[i, 2].set_title("pred")
    fig.tight_layout()

    if save:
        fig.savefig(save)
    else:
        fig.show()
        plt.show()
# (Web-page residue, not part of the code: "Comment list" / "Article table of contents")