def combind_loss(logits, labels, reg_preds, reg_labels, alpha=1.0, beta=0.025):
    """Build the weighted classification + regression training loss.

    The total loss is ``alpha * mean_cross_entropy + beta * rmse`` plus any
    other tensors that are already registered in the graph collection
    ``"losses"`` (everything in that collection is summed).

    Args:
        logits: Unscaled class scores passed to
            ``tf.nn.sparse_softmax_cross_entropy_with_logits``; presumably
            shaped (batch, num_classes) -- TODO confirm against the model.
        labels: Class indices matching ``logits``; cast to int64.
        reg_preds: Regression predictions; flattened to shape (-1, 1).
        reg_labels: Regression targets; flattened to shape (-1, 1) and cast
            to the dtype of ``reg_preds``.
        alpha: Weight applied to the cross-entropy term (default 1.0).
        beta: Weight applied to the RMSE term (default 0.025).

    Returns:
        A tuple ``(total_loss, cem, rmse)``.
        - ``total_loss`` is the ``tf.add_n`` of every tensor in the
          ``"losses"`` collection (op name ``'combinded_loss'``).
        - ``cem`` is the *unweighted* mean cross-entropy.
        - ``rmse`` is the *unweighted* root-mean-squared error.

    Side effects:
        Appends the two weighted terms to the graph collection ``"losses"``.
        NOTE(review): calling this more than once in the same graph makes
        ``total_loss`` also include the terms added by every earlier call.
    """
    # Floor on the MSE before taking the square root. d sqrt(x)/dx is
    # infinite at x == 0, so a perfect regression fit would otherwise
    # produce NaN gradients (inf * 0). The floor only alters the value when
    # MSE < 1e-12, i.e. by at most 1e-6 in the reported RMSE.
    mse_floor = 1e-12

    # --- classification term ---
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    cem = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection("losses", cem * alpha)

    # --- regression term ---
    # Flatten BOTH operands to (N, 1). Reshaping only the labels lets an (N,)
    # prediction vector broadcast against them into an (N, N) matrix, which
    # silently computes the wrong error.
    reg_preds = tf.reshape(reg_preds, (-1, 1))
    # squared_difference requires matching dtypes; align targets to the
    # predictions (same idea as the int64 cast applied to `labels` above).
    reg_labels = tf.cast(tf.reshape(reg_labels, (-1, 1)), reg_preds.dtype)
    mse = tf.reduce_mean(tf.squared_difference(reg_labels, reg_preds))
    rmse = tf.sqrt(tf.maximum(mse, mse_floor))
    tf.add_to_collection("losses", rmse * beta)

    # The op name keeps its original spelling so that graph lookups by name
    # (e.g. 'combinded_loss:0') keep working.
    total_loss = tf.add_n(tf.get_collection("losses"), name='combinded_loss')
    return total_loss, cem, rmse
评论列表
文章目录