losses.py source code

python

Project: youtube-8m  Author: wangheda
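import numpy as np
import tensorflow as tf  # TensorFlow 1.x API (tf.log, keep_dims)

def calculate_loss_distill(self, predictions, labels_distill, labels, **unused_params):
    with tf.name_scope("loss_distill"):
      print("loss_distill")
      epsilon = 10e-6  # note: 10e-6 == 1e-5; keeps the logs away from log(0)
      float_labels = tf.cast(labels, tf.float32)
      float_labels_distill = tf.cast(labels_distill, tf.float32)
      # Label-to-label similarity matrix, read from disk each time the graph is built.
      # The matmul below requires it to be square (vocab_size x vocab_size).
      embedding_mat = np.loadtxt("./resources/embedding_matrix.model")
      vocab_size = embedding_mat.shape[1]
      labels_size = float_labels.get_shape().as_list()[1]
      embedding_mat = tf.cast(embedding_mat, dtype=tf.float32)
      # Term 1: Bernoulli log-likelihood against the ground-truth labels.
      cross_entropy_loss_1 = float_labels * tf.log(predictions + epsilon) + (
          1 - float_labels) * tf.log(1 - predictions + epsilon)
      # Term 2: smoothed targets = mean embedding row over each example's positive labels.
      float_labels_1 = float_labels[:, :vocab_size]
      num_positive = tf.reduce_sum(float_labels_1, axis=1, keep_dims=True)
      # tf.maximum avoids 0/0 (NaN) for a row with no positive label in the first vocab_size columns.
      labels_smooth = tf.matmul(float_labels_1, embedding_mat) / tf.maximum(num_positive, 1.0)
      # Repeat the smoothed targets so they cover all labels_size prediction columns.
      float_classes = labels_smooth
      for _ in range(labels_size // vocab_size - 1):
        float_classes = tf.concat((float_classes, labels_smooth), axis=1)
      cross_entropy_loss_2 = float_classes * tf.log(predictions + epsilon) + (
          1 - float_classes) * tf.log(1 - predictions + epsilon)
      # Term 3: log-likelihood against the distillation (teacher) labels.
      cross_entropy_loss_3 = float_labels_distill * tf.log(predictions + epsilon) + (
          1 - float_labels_distill) * tf.log(1 - predictions + epsilon)

      # Each term is weighted 0.5 (total weight 1.5); negating turns the log-likelihood into cross-entropy.
      cross_entropy_loss = cross_entropy_loss_1*0.5 + cross_entropy_loss_2*0.5 + cross_entropy_loss_3*0.5
      cross_entropy_loss = tf.negative(cross_entropy_loss)

      # Sum over classes, then average over the batch.
      return tf.reduce_mean(tf.reduce_sum(cross_entropy_loss, 1))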
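The loss combines three targets (ground-truth labels, embedding-smoothed labels, distillation labels) against the same predictions. The sketch below re-expresses that arithmetic in plain Python on nested lists so it can be run without TensorFlow. It is an illustration under stated assumptions, not the project's code: the helper names `bernoulli_loglik`, `smooth_labels` and `distill_loss` are hypothetical, the embedding matrix is passed in as a nested list instead of being read from `./resources/embedding_matrix.model`, it is assumed square, the label length is assumed to be a multiple of its size, and the `max(..., 1.0)` guard against a row with no positive label is this sketch's own choice.

```python
import math

# Hypothetical pure-Python re-expression of calculate_loss_distill
# (not taken from the youtube-8m repository).

EPS = 10e-6  # same constant as the TensorFlow code (10e-6 == 1e-5)


def bernoulli_loglik(targets, preds, eps=EPS):
    """Hypothetical helper: sum over classes of t*log(p) + (1 - t)*log(1 - p)."""
    return sum(t * math.log(p + eps) + (1 - t) * math.log(1 - p + eps)
               for t, p in zip(targets, preds))


def smooth_labels(labels, embedding):
    """Hypothetical helper: average the embedding rows of the positive labels,
    then repeat the result to cover len(labels) columns.

    Assumes a square vocab_size x vocab_size embedding and that len(labels)
    is a multiple of vocab_size.
    """
    vocab_size = len(embedding[0])
    head = labels[:vocab_size]
    # Assumption of this sketch: avoid 0/0 when no label in the head is positive.
    count = max(sum(head), 1.0)
    smoothed = [sum(head[i] * embedding[i][j] for i in range(vocab_size)) / count
                for j in range(vocab_size)]
    return smoothed * (len(labels) // vocab_size)


def distill_loss(predictions, labels_distill, labels, embedding):
    """Hypothetical helper: batch mean of the negated, 0.5-weighted
    sum of the three per-example log-likelihood terms."""
    total = 0.0
    for p, t_distill, t in zip(predictions, labels_distill, labels):
        t_smooth = smooth_labels(t, embedding)
        loglik = (0.5 * bernoulli_loglik(t, p)
                  + 0.5 * bernoulli_loglik(t_smooth, p)
                  + 0.5 * bernoulli_loglik(t_distill, p))
        total += -loglik
    return total / len(predictions)
```

With an identity embedding the smoothed targets reduce to the label vector divided by its number of positives: `smooth_labels([1.0, 1.0], [[1.0, 0.0], [0.0, 1.0]])` returns `[0.5, 0.5]`. Because every term is weighted 0.5, a one-class example whose three targets are all 1 and whose prediction is 0.5 gives a loss of `-1.5 * log(0.5 + 1e-5)`.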