def get_weights_by_predictions(labels_batch, predictions, high_weight=3.0,
                               low_weight=0.5, epsilon=1e-6):
    """Return per-example weights that up-weight examples with above-average loss.

    The element-wise binary cross-entropy between ``labels_batch`` and
    ``predictions`` is summed over axis 1, giving one loss value per row.
    Rows whose loss is strictly greater than the batch mean (offset by
    ``epsilon``) receive ``high_weight``; every other row receives
    ``low_weight``.

    Args:
      labels_batch: Label tensor, cast to float32 here. Presumably multi-hot
        labels of shape (batch, num_classes) -- TODO confirm against callers.
      predictions: Predicted probabilities with the same shape as
        ``labels_batch``. Assumed to lie in [0, 1] and to be float32, so that
        they combine with the float32-cast labels.
      high_weight: Weight given to rows whose loss exceeds the threshold.
        Defaults to 3.0, the previously hard-coded value.
      low_weight: Weight given to all other rows. Defaults to 0.5, the
        previously hard-coded value.
      epsilon: Small constant added inside both logarithms to keep them finite,
        and folded into the batch-mean threshold.

    Returns:
      A tensor with one weight per row of the inputs (the loss is reduced over
      axis 1), each entry equal to ``high_weight`` or ``low_weight``.
    """
    float_labels = tf.cast(labels_batch, dtype=tf.float32)
    # Element-wise binary log-likelihood. epsilon keeps log() finite when a
    # prediction is exactly 0 or 1.
    # NOTE(review): tf.log is the TF1 name; under TF2 this must be tf.math.log.
    log_likelihood = (
        float_labels * tf.log(predictions + epsilon)
        + (1 - float_labels) * tf.log(1 - predictions + epsilon))
    ce = tf.reduce_sum(tf.negative(log_likelihood), axis=1)
    # reduce_mean(ce + epsilon) is mean(ce) + epsilon. The threshold therefore
    # sits slightly above the mean, and rows whose loss equals the mean (e.g.
    # when all rows tie) get low_weight.
    mean_ce = tf.reduce_mean(ce + epsilon)
    # Both branches are constants, so the returned weights carry no gradient.
    weights = tf.where(ce > mean_ce,
                       high_weight * tf.ones_like(ce),
                       low_weight * tf.ones_like(ce))
    return weights
评论列表
文章目录