def rmse(Y_true, Y_pred, channel_names=('inflow', 'outflow', 'newflow', 'endflow')):
    """Print per-channel and overall RMSE between two arrays; return the overall RMSE.

    RMSE definition: https://www.kaggle.com/wiki/RootMeanSquaredError

    Axis 1 of both arrays is treated as the channel axis. For every channel
    ``i`` the RMSE of ``Y_true[:, i]`` vs ``Y_pred[:, i]`` (each flattened) is
    printed, labelled ``channel_names[i]``. Channels without a name are
    labelled ``channel <i>``. The RMSE over all elements is then printed and
    returned.

    Args:
        Y_true: Ground-truth array with at least two dimensions (presumably a
            NumPy array; it must support ``[:, i]`` indexing and ``.flatten()``).
        Y_pred: Predicted array, indexed the same way as ``Y_true``.
        channel_names: Labels for the leading channels. The default matches
            the four labels this function has always printed.

    Returns:
        The RMSE computed over every element of ``Y_true`` and ``Y_pred``.

    Raises:
        IndexError: If ``Y_true`` has fewer than two dimensions, or ``Y_pred``
            cannot be indexed at a channel present in ``Y_true``.
        ValueError: Raised by scikit-learn's ``mean_squared_error`` when the
            compared flattened arrays differ in length.
    """
    # Imported lazily so scikit-learn is only required when this is called.
    from sklearn.metrics import mean_squared_error

    print('shape:', Y_true.shape, Y_pred.shape)
    print("===RMSE===")
    # NOTE: elements are paired after flattening, so arrays of equal size but
    # different layout are compared element-by-element without a shape check.
    for idx in range(Y_true.shape[1]):
        if idx < len(channel_names):
            label = channel_names[idx]
        else:
            # Report extra channels too, rather than silently skipping them.
            label = 'channel {}'.format(idx)
        channel_rmse = mean_squared_error(
            Y_true[:, idx].flatten(), Y_pred[:, idx].flatten()) ** 0.5
        print(label + ': ', channel_rmse)
    total_rmse = mean_squared_error(Y_true.flatten(), Y_pred.flatten()) ** 0.5
    print("total rmse: ", total_rmse)
    print("===RMSE===")
    return total_rmse
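# 评论列表 ("Comment list") — likely residue from the source web page, not code
# 文章目录 ("Table of contents") — likely residue from the source web page, not code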