def get_support(self, labels, support_type=None):
    """Build the support-label tensor for the requested support type(s).

    Args:
        labels: Label tensor. Every branch casts it to float32; the
            "vertical" branch matrix-multiplies it with a
            [num_classes, num_verticals] mapping and the "frequent" branch
            slices its second axis, so it is presumably shaped
            [batch, num_classes] -- TODO confirm against callers.
        support_type: "vertical", "frequent", "label", or a comma-separated
            combination of them (e.g. "vertical,label"). When None, the
            value of FLAGS.support_type is used.

    Returns:
        A float32 tensor of support labels. For a comma-separated
        support_type the per-type results are concatenated along axis 1,
        in the order the types are listed.

    Raises:
        NotImplementedError: If support_type (or one of its comma-separated
            parts) is not one of the recognized names.
    """
    if support_type is None:
        support_type = FLAGS.support_type

    if "," in support_type:
        # Resolve each listed type recursively, then join them feature-wise.
        parts = [
            tf.cast(self.get_support(labels, part), dtype=tf.float32)
            for part in support_type.split(",")
        ]
        return tf.concat(parts, axis=1)

    if support_type == "vertical":
        num_classes = FLAGS.num_classes
        num_verticals = FLAGS.num_verticals

        # Binary class -> vertical membership matrix, filled from the mapping
        # file at graph-construction time.
        vertical_mapping = np.zeros([num_classes, num_verticals], dtype=np.float32)
        with open(FLAGS.vertical_file) as mapping_file:
            for line in mapping_file:
                # Materialize the tokens: on Python 3 map() is a lazy
                # iterator without len(), so the original length check
                # raised TypeError.
                group = [int(token) for token in line.strip().split()]
                # Lines that are not exactly "<class> <vertical>" are skipped.
                if len(group) == 2:
                    class_index, vertical_index = group
                    vertical_mapping[class_index, vertical_index] = 1

        # NOTE(review): tf.get_variable("vm") is created without reuse, so
        # building this branch twice in the same variable scope presumably
        # fails -- confirm how callers scope this.
        vm = tf.get_variable(
            "vm",
            shape=[num_classes, num_verticals],
            trainable=False,
            initializer=tf.constant_initializer(vertical_mapping),
        )
        float_labels = tf.cast(labels, dtype=tf.float32)
        vertical_labels = tf.matmul(float_labels, vm)
        # Binarize: an entry is 1.0 where the matmul result exceeds 0.2.
        return tf.cast(vertical_labels > 0.2, tf.float32)

    if support_type == "frequent":
        # Keep only the first num_frequents columns of the labels.
        frequent_labels = tf.slice(
            labels, begin=[0, 0], size=[-1, FLAGS.num_frequents]
        )
        return tf.cast(frequent_labels, dtype=tf.float32)

    if support_type == "label":
        return tf.cast(labels, dtype=tf.float32)

    raise NotImplementedError("Unsupported support_type: %r" % (support_type,))
评论列表
文章目录