train.py source code


Project: neuroimage-tensorflow, author: corticometrics
import tensorflow as tf

# print_tensor_shape(tensor, msg) is a helper from the same project (not shown here).

def loss_fn(logits, labels):
    # input:  logits: Logits tensor, float - [batch_size, 256, 256, 256, 2].
    # input:  labels: Labels tensor, int8 - [batch_size, 256, 256, 256].
    # output: loss: scalar Loss tensor of type float.

    # sparse softmax cross entropy needs int32/int64 labels; tf.cast replaces
    # the deprecated tf.to_int64 and behaves the same
    labels = tf.cast(labels, tf.int64)
    print_tensor_shape(logits, 'logits shape ')
    print_tensor_shape(labels, 'labels shape ')

    # flatten to one row per voxel: logits -> [N, 2], labels -> [N]
    logits_re = tf.reshape(logits, [-1, 2])
    labels_re = tf.reshape(labels, [-1])
    #print_tensor_shape(logits_re, 'logits shape after')
    #print_tensor_shape(labels_re, 'labels shape after')

    # per-voxel cross entropy on the flattened tensors
    # (the original passed the unreshaped logits/labels, leaving logits_re/labels_re unused)
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits_re, labels=labels_re, name='cross_entropy')
    print_tensor_shape(cross_entropy, 'cross_entropy shape ')

    # average over every voxel in the batch
    loss = tf.reduce_mean(cross_entropy, name='1cnn_cross_entropy_mean')
    print_tensor_shape(loss, 'loss shape ')

    return loss
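For reference, the quantity loss_fn computes can be written out without TensorFlow. The sketch below is a plain-Python illustration, not code from the neuroimage-tensorflow project: the helper names flatten_voxels and sparse_softmax_xent_mean are hypothetical, and it assumes logits and labels are nested Python lists whose last logits axis holds the two class scores. It flattens the voxels the way the two tf.reshape calls do, then averages -log softmax(logits)[label] using the log-sum-exp trick for numerical stability.

```python
import math


def flatten_voxels(logits, labels):
    """Hypothetical helper: flatten nested voxel grids into [N, 2] and [N] lists.

    Mirrors tf.reshape(logits, [-1, 2]) and tf.reshape(labels, [-1]) under the
    assumption that labels has the same nesting as logits minus the class axis.
    """
    flat_logits, flat_labels = [], []

    def walk(lg, lb):
        if isinstance(lb, int):          # reached a single voxel
            flat_logits.append(lg)       # lg is the [score_0, score_1] pair
            flat_labels.append(lb)
        else:
            for sub_lg, sub_lb in zip(lg, lb):
                walk(sub_lg, sub_lb)

    walk(logits, labels)
    return flat_logits, flat_labels


def sparse_softmax_xent_mean(logits, labels):
    """Hypothetical helper: mean over voxels of -log softmax(logits)[label]."""
    flat_logits, flat_labels = flatten_voxels(logits, labels)
    total = 0.0
    for row, label in zip(flat_logits, flat_labels):
        m = max(row)                                          # shift by max for stability
        log_sum_exp = m + math.log(sum(math.exp(z - m) for z in row))
        total += log_sum_exp - row[label]                     # equals -log p(label)
    return total / len(flat_labels)


# one "volume" of two voxels: shape [1, 2, 2] logits, [1, 2] labels
print(sparse_softmax_xent_mean([[[0.0, 0.0], [2.0, 0.0]]], [[0, 1]]))
```

Because tf.reduce_mean with no axis argument averages over every element, the mean over the flattened [N] vector equals the mean over the original [batch_size, 256, 256, 256] shape; the reshape mainly makes the per-voxel layout explicit.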