losses.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:youtube-8m 作者: wangheda 项目源码 文件源码
def get_support(self, labels, support_type=None):
  """Build the auxiliary "support" target tensor for ``labels``.

  Args:
    labels: label tensor. The ``tf.slice`` (begin=[0, 0]) and ``tf.matmul``
      calls below require it to be rank 2; presumably shaped
      (batch, num_classes) with 0/1 entries -- TODO confirm against callers.
    support_type: name of the support target to build; defaults to
      ``FLAGS.support_type`` when None. Accepted values:
        * "label": ``labels`` cast to float32.
        * "frequent": the first ``FLAGS.num_frequents`` columns of
          ``labels``, cast to float32.
        * "vertical": a 0/1 indicator per vertical, obtained by mapping the
          labels through a class->vertical matrix read from
          ``FLAGS.vertical_file``.
        * a comma-separated list of the above (surrounding whitespace is
          ignored): each support is built and the results are concatenated
          along axis 1.

  Returns:
    A float32 tensor.

  Raises:
    NotImplementedError: if ``support_type`` (or any comma-separated entry)
      is not one of the names above.
  """
  if support_type is None:
    support_type = FLAGS.support_type

  if "," in support_type:
    # Build each requested support independently and place them side by side.
    new_labels = [
        tf.cast(self.get_support(labels, st.strip()), dtype=tf.float32)
        for st in support_type.split(",")]
    return tf.concat(new_labels, axis=1)

  if support_type == "vertical":
    num_classes = FLAGS.num_classes
    num_verticals = FLAGS.num_verticals
    vertical_file = FLAGS.vertical_file
    vertical_mapping = np.zeros([num_classes, num_verticals], dtype=np.float32)
    float_labels = tf.cast(labels, dtype=tf.float32)
    # Each line is expected to be "<class_index> <vertical_index>". Lines that
    # do not have exactly two whitespace-separated fields (e.g. blank lines)
    # are skipped; a non-integer field raises ValueError from int().
    with open(vertical_file) as f:
      for line in f:
        # list() is required: under Python 3, map() returns an iterator,
        # which has no len().
        group = list(map(int, line.strip().split()))
        if len(group) == 2:
          x, y = group
          vertical_mapping[x, y] = 1
    vm_init = tf.constant_initializer(vertical_mapping)
    # NOTE(review): tf.get_variable raises if "vm" already exists in the
    # current variable scope without reuse enabled, so building the
    # "vertical" support twice in one scope likely fails -- confirm.
    vm = tf.get_variable("vm", shape=[num_classes, num_verticals],
                         trainable=False, initializer=vm_init)
    # Entry (i, v) sums example i's labels over the classes mapped to
    # vertical v; presumably the 0.2 threshold means "at least one positive
    # label in this vertical" for 0/1 labels.
    vertical_labels = tf.matmul(float_labels, vm)
    return tf.cast(vertical_labels > 0.2, tf.float32)

  if support_type == "frequent":
    num_frequents = FLAGS.num_frequents
    # Keep all rows and the first num_frequents columns.
    frequent_labels = tf.slice(labels, begin=[0, 0], size=[-1, num_frequents])
    return tf.cast(frequent_labels, dtype=tf.float32)

  if support_type == "label":
    return tf.cast(labels, dtype=tf.float32)

  raise NotImplementedError("unsupported support_type: %r" % (support_type,))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号