losses.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def calculate_loss_mix(self, predictions, predictions_class, labels,
                       class_loss_weight=0.1, **unused_params):
    """Return the main label loss plus a weighted auxiliary ("support") loss.

    The auxiliary target is derived from `labels` according to
    `FLAGS.support_type`:

      * "class":    labels are multiplied by a one-hot matrix built from the
                    integer sequence stored in `FLAGS.class_file` (depth
                    `FLAGS.encoder_size`), then binarized so that every
                    strictly positive entry becomes 1.0.
      * "frequent": the first `FLAGS.encoder_size` columns of the labels.
      * "encoder":  labels pushed through `FLAGS.encoder_layers` fixed affine
                    layers loaded from `FLAGS.autoencoder_dir` (ReLU between
                    layers, sigmoid after the last one). This branch scores
                    with `self.calculate_mseloss` instead of
                    `self.calculate_loss`.
      * otherwise:  the labels repeated along axis 1, `FLAGS.moe_layers`
                    copies in total (a single copy if `moe_layers` <= 1).

    Args:
      predictions: predictions for the main task; scored against `labels`
        with `self.calculate_loss`.
      predictions_class: predictions for the auxiliary task; scored against
        the derived target described above.
      labels: ground-truth labels, cast to float32 here. Expected to be 2-D
        (it is used with matmul, column slicing and axis-1 concat) --
        presumably (batch, num_labels); confirm against the caller.
      class_loss_weight: multiplier applied to the auxiliary loss before it
        is added to the main loss. Defaults to 0.1, the value that used to
        be hard-coded.
      **unused_params: accepted and ignored.

    Returns:
      `self.calculate_loss(predictions, labels)
         + class_loss_weight * auxiliary_loss`.
    """
    with tf.name_scope("loss_mix"):
      float_labels = tf.cast(labels, tf.float32)
      if FLAGS.support_type == "class":
        # One row per entry of the loaded sequence; presumably entry k is the
        # class id of label k -- verify against the class_file format.
        seq = np.loadtxt(FLAGS.class_file)
        tf_seq = tf.one_hot(tf.constant(seq, dtype=tf.int32), FLAGS.encoder_size)
        float_classes_org = tf.matmul(float_labels, tf_seq)
        # The product is a per-column sum; clamp it back to a {0, 1} indicator.
        class_true = tf.ones(tf.shape(float_classes_org))
        class_false = tf.zeros(tf.shape(float_classes_org))
        float_classes = tf.where(
            tf.greater(float_classes_org, class_false), class_true, class_false)
        support_loss = self.calculate_loss(predictions_class, float_classes)
      elif FLAGS.support_type == "frequent":
        float_classes = float_labels[:, 0:FLAGS.encoder_size]
        support_loss = self.calculate_loss(predictions_class, float_classes)
      elif FLAGS.support_type == "encoder":
        float_classes = float_labels
        for i in range(FLAGS.encoder_layers):
          # The directory flag is concatenated as-is, so it must already end
          # with a path separator.
          var_i = np.loadtxt(FLAGS.autoencoder_dir + 'autoencoder_layer%d.model' % i)
          # Each file holds the weight matrix with the bias as its last row.
          weight_i = tf.constant(var_i[:-1, :], dtype=tf.float32)
          bias_i = tf.reshape(tf.constant(var_i[-1, :], dtype=tf.float32), [-1])
          float_classes = tf.nn.xw_plus_b(float_classes, weight_i, bias_i)
          if i < FLAGS.encoder_layers - 1:
            float_classes = tf.nn.relu(float_classes)
          else:
            # The final layer yields soft targets in (0, 1); they are not
            # thresholded.
            float_classes = tf.nn.sigmoid(float_classes)
        support_loss = self.calculate_mseloss(predictions_class, float_classes)
      else:
        # Replicate the labels side by side so the target has one copy per
        # MoE layer.
        float_classes = float_labels
        for _ in range(FLAGS.moe_layers - 1):
          float_classes = tf.concat((float_classes, float_labels), axis=1)
        support_loss = self.calculate_loss(predictions_class, float_classes)
      main_loss = self.calculate_loss(predictions, labels)
      return main_loss + class_loss_weight * support_loss
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号