def get_label_queue(self, batch_size):
    """Build a shuffled, batched TF1 queue of per-attribute label tensors.

    The pipeline works in four steps:

    1. ``self.attr.values`` is converted to a constant uint8 tensor.
    2. ``tf.train.slice_input_producer`` emits it one row at a time.
    3. Each row is cast to float32 and split into one tensor per name in
       ``self.label_names``.
    4. The resulting dict is batched with ``tf.train.shuffle_batch``.

    As with any ``tf.train`` queue pipeline, the queue runners must be
    started before the returned tensors are evaluated.

    Args:
        batch_size: Number of rows per dequeued batch.

    Returns:
        dict mapping each name in ``self.label_names`` to a float32 tensor.
        Each tensor has shape ``(batch_size, 1)`` when the number of label
        columns equals ``len(self.label_names)``.
    """
    # The original author's note says the labels are 0/1, which fits uint8;
    # values outside [0, 255] would not. NOTE(review): confirm the encoding
    # of self.attr.
    tf_labels = tf.convert_to_tensor(self.attr.values, dtype=tf.uint8)

    with tf.name_scope('label_queue'):
        # Slices along the first axis, so each dequeue yields one label row.
        uint_label = tf.train.slice_input_producer([tf_labels])[0]
    # tf.to_float is deprecated; tf.cast to float32 is its direct replacement.
    label = tf.cast(uint_label, tf.float32)

    # All labels, not just those in the causal model. tf.split divides the
    # row evenly into len(self.label_names) pieces. This assumes self.attr has
    # exactly one column per label name, in the same order — TODO confirm.
    per_label = tf.split(label, len(self.label_names))
    dict_data = dict(zip(self.label_names, per_label))

    # Always use at least one thread. Presumably the "- 3" leaves workers for
    # other input pipelines (NOTE(review): confirm intent).
    num_preprocess_threads = max(self.num_worker - 3, 1)

    # capacity must exceed min_after_dequeue. The 3 * batch_size margin gives
    # the enqueue threads headroom above the shuffling floor.
    data_batch = tf.train.shuffle_batch(
        dict_data,
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=self.min_queue_examples + 3 * batch_size,
        min_after_dequeue=self.min_queue_examples,
    )
    return data_batch
评论列表
文章目录