def regression_loss(reg_preds, reg_labels, epsilon=1e-12):
    """Add an RMSE term to the 'losses' collection and return the summed total.

    Computes the root-mean-squared error between ``reg_labels`` and
    ``reg_preds``, appends it to the ``'losses'`` collection, and returns the
    sum of every tensor currently in that collection.

    Args:
        reg_preds: Predicted values. Assumed to be a float tensor that is
            broadcast-compatible with ``reg_labels`` (TODO confirm shapes
            against callers).
        reg_labels: Ground-truth target values.
        epsilon: Floor applied to the mean squared error before the square
            root. The derivative of ``sqrt(x)`` is infinite at ``x == 0``, so
            without a floor a batch whose predictions exactly match the
            labels produces NaN gradients. For any MSE >= ``epsilon`` the
            forward value and gradient are unchanged. Pass ``0.0`` to get the
            unguarded behaviour. For float16 inputs use a larger value, since
            ``1e-12`` underflows to zero in that dtype.

    Returns:
        A scalar tensor named ``"total_loss"``. It is the sum of all tensors
        in the ``'losses'`` collection: this RMSE plus anything other code
        has added there (presumably regularisation terms; verify against the
        rest of the file).

    Note:
        The RMSE is appended on every call. Calling this function more than
        once in the same graph therefore counts the RMSE term more than once
        in the returned total.
    """
    mse = tf.reduce_mean(tf.squared_difference(reg_labels, reg_preds))
    # Clamp before sqrt so the gradient stays finite when mse == 0.
    rmse = tf.sqrt(tf.maximum(mse, epsilon))
    tf.add_to_collection('losses', rmse)
    return tf.add_n(tf.get_collection('losses'), name="total_loss")
评论列表
文章目录