eval_functions.py source code

python

Project: sat-seg    Author: mshiv
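import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.add_to_collection, tf.get_collection)


def loss(logits, labels, num_classes):
    """Mean sparse softmax cross-entropy, summed with any other terms in the 'losses' collection."""
    # Flatten per-pixel predictions to shape [num_pixels, num_classes].
    logits = tf.reshape(logits, [-1, num_classes])

    # sparse_softmax_cross_entropy_with_logits expects integer class indices
    # of shape [num_pixels]. (softmax_cross_entropy_with_logits would instead
    # need float one-hot labels of shape [num_pixels, num_classes].)
    labels = tf.cast(tf.reshape(labels, [-1]), tf.int64)
    print('shape of logits: %s' % str(logits.get_shape()))
    print('shape of labels: %s' % str(labels.get_shape()))

    # Since TensorFlow 1.0, labels and logits must be passed by keyword
    # (pre-1.0 releases accepted them positionally as (logits, labels)).
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name='Cross_Entropy')
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='xentropy_mean')
    tf.add_to_collection('losses', cross_entropy_mean)

    # Total loss = this cross-entropy term plus any other terms that other
    # code has already added to the 'losses' collection.
    total_loss = tf.add_n(tf.get_collection('losses'), name='total_loss')
    return total_loss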
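To make the math behind the loss explicit, here is a minimal pure-Python sketch (no TensorFlow) of what the reshape plus sparse softmax cross-entropy computes: each pixel's loss is -log(softmax(logits)[label]), and the function then averages over pixels. It also checks that integer labels and float one-hot labels give the same per-pixel loss, which is why the two TF ops differ only in label format. The helper names (`flatten`, `log_softmax`, `sparse_cross_entropy`, `dense_cross_entropy`) and the toy 1x2 "image" are hypothetical illustrations, not part of the sat-seg code.

```python
import math
from typing import List, Sequence


def flatten(nested_logits, nested_labels):
    """Flatten [H][W][C] logits and [H][W] labels, like tf.reshape(..., [-1, C]) and [-1]."""
    flat_logits = [pixel for row in nested_logits for pixel in row]
    flat_labels = [label for row in nested_labels for label in row]
    return flat_logits, flat_labels


def log_softmax(row: Sequence[float]) -> List[float]:
    # Subtract the row maximum before exponentiating for numerical
    # stability; this leaves the softmax result unchanged.
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum_exp for x in row]


def sparse_cross_entropy(logits, labels) -> List[float]:
    """Integer labels: loss_i = -log_softmax(logits_i)[labels_i]."""
    return [-log_softmax(row)[k] for row, k in zip(logits, labels)]


def dense_cross_entropy(logits, one_hot) -> List[float]:
    """One-hot float labels: loss_i = -sum_j y_ij * log_softmax(logits_i)_j."""
    return [-sum(y * lp for y, lp in zip(ys, log_softmax(row)))
            for row, ys in zip(logits, one_hot)]


# A hypothetical 1x2 "image" with 3 classes.
image_logits = [[[2.0, 1.0, 0.1], [0.5, 2.5, 0.0]]]
image_labels = [[0, 1]]

logits, labels = flatten(image_logits, image_labels)
one_hot = [[1.0 if j == k else 0.0 for j in range(3)] for k in labels]

sparse = sparse_cross_entropy(logits, labels)
dense = dense_cross_entropy(logits, one_hot)
mean_loss = sum(sparse) / len(sparse)  # analogue of tf.reduce_mean
print(sparse, dense, mean_loss)
```

With all-equal logits over C classes the loss is log(C), which is a quick sanity check for any implementation.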
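The final `tf.add_n(tf.get_collection('losses'))` follows a common TF1 pattern in which other code (for example, an L2 weight-decay penalty on each weight tensor) also appends scalars to the 'losses' collection, so the returned value is the sum of all of them. Whether sat-seg adds such terms is not shown here. The sketch below is a hypothetical pure-Python stand-in for a graph's named collections; `add_to_collection`, `get_collection`, and `l2_weight_decay` here are illustrative helpers, not TensorFlow calls. The decay term mirrors `wd * tf.nn.l2_loss(w)`, where `l2_loss` is sum(w**2) / 2.

```python
from collections import defaultdict
from typing import DefaultDict, List

# Hypothetical stand-in for a TF1 graph's named collections.
_collections: DefaultDict[str, List[float]] = defaultdict(list)


def add_to_collection(name: str, value: float) -> None:
    _collections[name].append(value)


def get_collection(name: str) -> List[float]:
    return list(_collections[name])


def l2_weight_decay(weights: List[float], wd: float) -> float:
    # Same value as wd * tf.nn.l2_loss(w): wd * sum(w ** 2) / 2.
    return wd * sum(w * w for w in weights) / 2.0


# A weight-decay term registered elsewhere (e.g. when building a layer) ...
add_to_collection('losses', l2_weight_decay([0.5, -1.0, 2.0], wd=0.01))
# ... and the cross-entropy mean registered by loss().
add_to_collection('losses', 1.2)

total_loss = sum(get_collection('losses'))  # analogue of tf.add_n
print(total_loss)
```

Because a collection lives as long as its graph, calling `loss()` twice in the same graph would append the cross-entropy term twice and double-count it in the total.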