losses.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def calculate_loss_mix(self, predictions, predictions_class, labels, **unused_params):
    with tf.name_scope("loss_softmax_mix"):
      vocab_size = labels.get_shape().as_list()[1]
      cross_entropy_class = tf.constant(0.0)
      for i in range(FLAGS.moe_layers):
        predictions_subclass = predictions_class[:,i*vocab_size:(i+1)*vocab_size]
        cross_entropy_class = cross_entropy_class + self.calculate_loss(predictions_subclass,labels)
      cross_entropy_loss = self.calculate_loss(predictions,labels)
      return cross_entropy_loss + 0.1*cross_entropy_class
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号