def loss_fn(logits, labels):
    """Compute the mean softmax cross-entropy loss over all voxels.

    Args:
        logits: Float tensor of per-voxel class scores; presumably shaped
            [batch_size, 256, 256, 256, 2] per the original header comment
            (TODO confirm against the caller). Last axis must have size 2.
        labels: Integer tensor of class ids (0 or 1), one rank lower than
            `logits` (e.g. [batch_size, 256, 256, 256]).

    Returns:
        loss: Scalar float tensor — mean cross-entropy over every voxel.
    """
    # tf.to_int64 is deprecated; tf.cast is the documented equivalent.
    labels = tf.cast(labels, tf.int64)
    print_tensor_shape(logits, 'logits shape ')
    print_tensor_shape(labels, 'labels shape ')

    # Flatten to [num_voxels, 2] / [num_voxels]. The original computed these
    # reshaped tensors but then fed the *un*-reshaped inputs to the loss,
    # leaving this step dead code; the reshaped tensors are now actually used.
    # (The mean over all elements is identical either way.)
    logits_re = tf.reshape(logits, [-1, 2])
    labels_re = tf.reshape(labels, [-1])

    # Sparse variant: labels are class indices, not one-hot vectors.
    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits_re, labels=labels_re, name='cross_entropy')
    print_tensor_shape(cross_entropy, 'cross_entropy shape ')

    loss = tf.reduce_mean(cross_entropy, name='1cnn_cross_entropy_mean')
    print_tensor_shape(loss, 'loss shape ')
    return loss