train.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def get_weights_by_predictions(labels_batch, predictions):
  epsilon = 1e-6
  float_labels = tf.cast(labels_batch, dtype=tf.float32)
  cross_entropy_loss = float_labels * tf.log(predictions + epsilon) + (
      1 - float_labels) * tf.log(1 - predictions + epsilon)
  ce = tf.reduce_sum(tf.negative(cross_entropy_loss), axis=1)
  mean_ce = tf.reduce_mean(ce + epsilon)
  weights = tf.where(ce > mean_ce, 
                     3.0 * tf.ones_like(ce),
                     0.5 * tf.ones_like(ce))
  return weights
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号