train.py source code

Project: CIKM2017 · Author: heliarmk
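```python
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def combind_loss(logits, labels, reg_preds, reg_labels):
    # Weights of the classification and regression terms.
    alpha = 1
    beta = 0.025

    # Classification term: mean sparse softmax cross-entropy over the batch.
    labels = tf.cast(labels, tf.int64)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='cross_entropy_per_example')
    cem = tf.reduce_mean(cross_entropy, name='cross_entropy')
    w_cem = cem * alpha
    tf.add_to_collection("losses", w_cem)

    # Regression term: RMSE between targets (reshaped to an [N, 1] column) and predictions.
    reg_labels = tf.reshape(reg_labels, (-1, 1))
    # rmse = tf.sqrt(tf.losses.mean_squared_error(reg_labels, reg_preds, loss_collection=None))
    rmse = tf.sqrt(tf.reduce_mean(tf.squared_difference(reg_labels, reg_preds)))
    w_rmse = rmse * beta
    tf.add_to_collection("losses", w_rmse)

    # Sum every tensor in the "losses" collection, including any terms (e.g. weight decay)
    # added elsewhere in the graph; also return the unweighted CE and RMSE for logging.
    return tf.add_n(tf.get_collection("losses"), name='combinded_loss'), cem, rmse
```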
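To make the arithmetic explicit, here is a minimal pure-Python sketch (no TensorFlow) of what `combind_loss` computes for one batch: total = alpha · CE + beta · RMSE, with alpha = 1 and beta = 0.025. It assumes nothing else has been added to the `"losses"` collection, so any extra terms such as weight decay are ignored. The helper name `combined_loss_reference` and its keyword defaults are hypothetical and only mirror the constants in the function.

```python
import math
from typing import Sequence, Tuple


def combined_loss_reference(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    reg_preds: Sequence[float],
    reg_labels: Sequence[float],
    alpha: float = 1.0,
    beta: float = 0.025,
) -> Tuple[float, float, float]:
    """Hypothetical reference: return (alpha * CE + beta * RMSE, CE, RMSE).

    Assumes no other terms (e.g. weight decay) contribute to the total.
    """
    # Mean sparse softmax cross-entropy: log-sum-exp(row) - row[label], averaged.
    ce_terms = []
    for row, label in zip(logits, labels):
        m = max(row)  # subtract the max logit for numerical stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        ce_terms.append(log_sum_exp - row[label])
    cem = sum(ce_terms) / len(ce_terms)

    # Root-mean-squared error between regression targets and predictions.
    squared = [(y - p) ** 2 for y, p in zip(reg_labels, reg_preds)]
    rmse = math.sqrt(sum(squared) / len(squared))

    return alpha * cem + beta * rmse, cem, rmse


total, cem, rmse = combined_loss_reference([[0.0, 0.0]], [0], [1.0, 3.0], [2.0, 2.0])
print(round(cem, 4), round(rmse, 4), round(total, 4))
```

With two equal logits and label 0 the cross-entropy is ln 2 ≈ 0.6931; predictions [1.0, 3.0] against targets [2.0, 2.0] give an RMSE of 1.0, so the weighted total is about 0.7181. Because beta is small, the regression term only shifts the total noticeably when the RMSE is large relative to the cross-entropy.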