resdeconv_model.py source code

python

Project: traffic_video_analysis  Author: polltooh
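import tensorflow as tf  # TensorFlow 1.x graph API (collections, tf.add_to_collection)
# `mf` is the project's own helper module that provides huber_loss; it is not shown on this page.

def loss(infer, count_diff_infer, label):
    """Total loss = per-image L2 density loss + c_lambda * Huber loss on the predicted count."""
    huber_epsilon = 5.0  # Huber threshold between the quadratic and linear regimes
    c_lambda = 0.1       # weight of the count loss relative to the density loss

    # Density loss: squared error summed over height, width and channels, then averaged over the batch.
    l2_loss = tf.reduce_mean(tf.reduce_sum(tf.square(infer - label), [1, 2, 3]), name='l2_loss')
    #l2_loss = mf.huber_loss(tf.reduce_sum(infer, [1,2,3]), tf.reduce_sum(label, [1,2,3]), huber_epsilon, 'density_loss')

    # Predicted count = learned correction term + sum (integral) of the predicted density map.
    count_infer = tf.add(tf.squeeze(count_diff_infer), tf.reduce_sum(infer, [1, 2, 3]), name='count_infer')
    # tf.mul was renamed tf.multiply in TensorFlow 1.0.
    count_loss = tf.multiply(c_lambda, mf.huber_loss(count_infer, tf.reduce_sum(label, [1, 2, 3]), huber_epsilon, 'huber_loss'),
                             name='count_loss')
    #count_loss = tf.multiply(c_lambda, tf.reduce_mean(tf.square(count_infer - tf.reduce_sum(label, [1,2,3]))),
    #                         name='count_loss')

    tf.add_to_collection('losses', count_loss)
    tf.add_to_collection('losses', l2_loss)

    # Sum everything in the 'losses' collection; this also picks up any tensors other code added to it.
    return tf.add_n(tf.get_collection('losses'), name='total_loss'), count_infer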
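The page does not show `mf.huber_loss`, so the exact count-loss formula is unknown. The following NumPy sketch restates `loss()` under stated assumptions. It is meant only to make the arithmetic concrete and is not the project's code. Three assumptions apply. First, the Huber loss uses the common definition (0.5·e² for |e| ≤ ε, ε·|e| − 0.5·ε² beyond) averaged over the batch. Second, `count_diff_infer` has shape `[batch, 1]`. Third, the names `huber_loss` and `total_loss` are hypothetical.

```python
import numpy as np

def huber_loss(pred, target, epsilon):
    """Mean Huber loss: quadratic for |error| <= epsilon, linear beyond it.

    Assumption: the common textbook definition; the project's mf.huber_loss
    is not shown and may differ (e.g. sum instead of mean, or other scaling).
    """
    err = np.abs(pred - target)
    quadratic = 0.5 * err ** 2
    linear = epsilon * err - 0.5 * epsilon ** 2
    return np.mean(np.where(err <= epsilon, quadratic, linear))

def total_loss(infer, count_diff_infer, label, huber_epsilon=5.0, c_lambda=0.1):
    """Hypothetical NumPy restatement of loss(): L2 density loss + weighted Huber count loss.

    infer, label: density maps of shape [batch, H, W, C]
    count_diff_infer: per-image count correction, assumed shape [batch, 1]
    """
    # Per-image sum of squared density errors, averaged over the batch.
    l2_loss = np.mean(np.sum((infer - label) ** 2, axis=(1, 2, 3)))
    # Predicted count = correction term + sum of the predicted density map.
    count_infer = np.squeeze(count_diff_infer, axis=1) + np.sum(infer, axis=(1, 2, 3))
    true_count = np.sum(label, axis=(1, 2, 3))
    count_loss = c_lambda * huber_loss(count_infer, true_count, huber_epsilon)
    return l2_loss + count_loss, count_infer

# Two 2x2 single-channel images: prediction is all zeros, ground truth all ones (true count 4 each).
infer = np.zeros((2, 2, 2, 1))
label = np.ones((2, 2, 2, 1))
total, count = total_loss(infer, np.array([[1.0], [10.0]]), label)
print(count)  # predicted counts: 1 and 10
print(total)  # ≈ 5.1
```

In this example the density term is 4 because each image misses four pixels by 1. For the count term, an error of 3 stays in the quadratic regime (4.5). An error of 6 exceeds ε = 5 and is penalized linearly (17.5). Their mean of 11 is scaled by λ = 0.1, giving 1.1. This is why a Huber count loss is less sensitive to a few badly miscounted frames than the commented-out squared-error alternative.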