import tensorflow as tf  # TF 1.x API (tf.log, reduction_indices, collections)


def loss(logits, labels):
    # Weighted two-class cross-entropy. H_FACTOR and T_FACTOR are the
    # per-class weights, assumed to be defined elsewhere in the module.
    labels = tf.cast(labels, tf.int64)
    batch_size = logits.get_shape()[0].value
    # One [H_FACTOR, T_FACTOR] row per example, shaped like the logits.
    weights = tf.constant(batch_size * [H_FACTOR, T_FACTOR], tf.float32,
                          shape=logits.get_shape())
    softmax = tf.nn.softmax(logits)
    # Clip so neither log(softmax) nor log(1 - softmax) sees 0.
    softmax = tf.clip_by_value(softmax, 1e-10, 0.999999)
    with tf.device('/cpu:0'):
        targets = tf.one_hot(labels, depth=2)
    # Per-example weighted binary cross-entropy, averaged over the two classes.
    cross_entropy = -tf.reduce_mean(weights * targets * tf.log(softmax) +
                                    weights * (1 - targets) * tf.log(1 - softmax),
                                    reduction_indices=[1])
    cross_entropy_mean = tf.reduce_mean(cross_entropy, name='cross_entropy')
    tf.add_to_collection('losses', cross_entropy_mean)
    # Sum with any regularization losses already added to the collection.
    return tf.add_n(tf.get_collection('losses'), name='total_loss')
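
For context, a minimal sketch of how this loss might be wired into a TF 1.x graph. The placeholders, the stand-in dense layer, and the H_FACTOR/T_FACTOR values here are illustrative assumptions, not taken from the original code:

# Assumed per-class weights; the original defines these elsewhere.
H_FACTOR, T_FACTOR = 3.0, 1.0

# Fixed batch size, since loss() reads the static shape of the logits.
features = tf.placeholder(tf.float32, shape=[32, 128])
labels = tf.placeholder(tf.int32, shape=[32])

# Stand-in model; any graph producing [batch_size, 2] logits would do.
logits = tf.layers.dense(features, units=2)

total_loss = loss(logits, labels)
train_op = tf.train.GradientDescentOptimizer(1e-3).minimize(total_loss)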