tf_core.py source code

python

Project: sparks  Author: ImpactHorizon
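import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.log, collections)

# H_FACTOR and T_FACTOR are per-class weights defined elsewhere in tf_core.py
# (their values are not shown in this snippet).

def loss(logits, labels):
    """Class-weighted cross-entropy for 2-class logits of shape [batch, 2]."""
    labels = tf.cast(labels, tf.int64)
    batch_size = logits.get_shape()[0].value  # requires a static batch size
    # Per-class weight matrix: column 0 gets H_FACTOR, column 1 gets T_FACTOR.
    weights = tf.constant(batch_size * [H_FACTOR, T_FACTOR], tf.float32,
                          shape=logits.get_shape())
    softmax = tf.nn.softmax(logits)
    # Clip probabilities so neither log(p) nor log(1 - p) evaluates log(0).
    softmax = tf.clip_by_value(softmax, 1e-10, 0.999999)

    with tf.device('/cpu:0'):
        targets = tf.one_hot(labels, depth=2)

    # Weighted binary cross-entropy on each class column, averaged over classes.
    cross_entropy = -tf.reduce_mean(
        weights * targets * tf.log(softmax) +
        weights * (1 - targets) * tf.log(1 - softmax),
        axis=1)
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection('losses', cross_entropy_mean)

    # Total loss: this cross-entropy plus any other terms that the rest of
    # the module adds to the 'losses' collection.
    return tf.add_n(tf.get_collection('losses'), name='total_loss')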
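To make the arithmetic concrete, here is a minimal pure-Python sketch of the same per-example computation: softmax, clipping, one-hot targets, class-column weights, a mean over the two classes, then a mean over the batch. It is an illustration, not the project's code. `weighted_two_class_loss`, `h_factor` and `t_factor` are hypothetical names; the last two stand in for H_FACTOR and T_FACTOR, whose real values are not shown. The sketch omits the 'losses' collection, so it corresponds to `cross_entropy_mean`, not `total_loss`.

```python
import math
from typing import Sequence


def weighted_two_class_loss(
    logits: Sequence[Sequence[float]],
    labels: Sequence[int],
    h_factor: float,  # hypothetical stand-in for H_FACTOR (class 0 weight)
    t_factor: float,  # hypothetical stand-in for T_FACTOR (class 1 weight)
    eps: float = 1e-10,
    upper: float = 0.999999,
) -> float:
    """Pure-Python mirror of cross_entropy_mean from loss() above."""
    class_weights = (h_factor, t_factor)
    per_example = []
    for row, label in zip(logits, labels):
        # Numerically stable softmax over the two logits.
        m = max(row)
        exps = [math.exp(z - m) for z in row]
        total = sum(exps)
        # Clip like tf.clip_by_value(softmax, 1e-10, 0.999999).
        probs = [min(max(e / total, eps), upper) for e in exps]
        terms = []
        for k, p in enumerate(probs):
            t = 1.0 if k == label else 0.0  # one-hot target
            w = class_weights[k]            # weight applies to the whole column
            terms.append(w * t * math.log(p) + w * (1.0 - t) * math.log(1.0 - p))
        per_example.append(-sum(terms) / len(terms))  # mean over the class axis
    return sum(per_example) / len(per_example)        # mean over the batch
```

Because each weight multiplies a whole class column, it scales both the positive and the negative term of that column. With `h_factor=2.0`, `t_factor=1.0` and equal logits, the loss is 1.5·ln 2 whether the label is 0 or 1; with both weights at 1.0 it reduces to ln 2.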