def create_scatter_plot(outfile_results, config):
    """Export true-vs-predicted values as a CSV table and a scatter plot.

    Reads the ``Prediction`` and ``y_true`` nodes from the root group of
    the HDF5 file ``outfile_results`` and writes two files into
    ``config.output_dir``:

    * ``<config.name>_results.csv`` -- the stacked values, comma-delimited;
    * ``<config.name>_results.png`` -- a scatter plot of true (x axis)
      against predicted (y axis) values.

    Args:
        outfile_results: Path of the HDF5 results file to read.
        config: Object providing ``output_dir`` and ``name`` attributes.

    Returns:
        None. The function is called for its file-writing side effects.
    """
    true_vs_pred = os.path.join(config.output_dir,
                                config.name + "_results.csv")
    true_vs_pred_plot = os.path.join(config.output_dir,
                                     config.name + "_results.png")

    with hdf.open_file(outfile_results, 'r') as f:
        prediction = f.get_node("/", "Prediction").read()
        y_true = f.get_node("/", "y_true").read()

        # NOTE(review): assumes both arrays are 1-D and of equal length, so
        # that the transposed stack has one row per sample with columns
        # (true, prediction) -- confirm against the writer of this file.
        np.savetxt(true_vs_pred, X=np.vstack([y_true, prediction]).T,
                   delimiter=',')

        # Keep a handle on the figure so it can be released afterwards:
        # pyplot retains every figure until it is explicitly closed, which
        # leaks memory when this function is called repeatedly.
        fig = plt.figure()
        try:
            plt.scatter(y_true, prediction)
            plt.title('true vs prediction')
            plt.xlabel('True')
            plt.ylabel('Prediction')
            plt.savefig(true_vs_pred_plot)
        finally:
            # Close even if plotting or saving fails.
            plt.close(fig)
评论列表
文章目录