data_loader.py source code

python

Project: CausalGAN  Author: mkocaoglu
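import tensorflow as tf  # TensorFlow 1.x API (queue-based input pipeline)

def get_label_queue(self, batch_size):
    # Attribute labels as a uint8 tensor; values are 0 or 1
    tf_labels = tf.convert_to_tensor(self.attr.values, dtype=tf.uint8)

    with tf.name_scope('label_queue'):
        # Emit one row of labels at a time from the full label tensor
        uint_label = tf.train.slice_input_producer([tf_labels])[0]
    label = tf.to_float(uint_label)

    # All labels, not just those in causal_model:
    # map each label name to its own length-1 slice of the row
    dict_data = {sl: tl for sl, tl in
                 zip(self.label_names, tf.split(label, len(self.label_names)))}

    # Use all but three workers for preprocessing, and at least one
    num_preprocess_threads = max(self.num_worker - 3, 1)

    data_batch = tf.train.shuffle_batch(
        dict_data,
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=self.min_queue_examples + 3 * batch_size,
        min_after_dequeue=self.min_queue_examples,
    )

    return data_batch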
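
The per-label splitting and the queue sizing in get_label_queue can be hard to picture without running TensorFlow. The following is a minimal plain-Python sketch of the same arithmetic, not the project's code: the helper names split_row, queue_capacity and preprocess_threads, the attribute names, and the sample numbers are all hypothetical and chosen only for illustration. It assumes one row of 0/1 labels with one entry per label name.

```python
# Hypothetical attribute names and one row of 0/1 labels (not taken from the project).
label_names = ['Male', 'Smiling', 'Young']
row = [1, 0, 1]


def split_row(row, names):
    """Mimic tf.split(label, len(names)): one length-1 slice per label name."""
    if len(row) != len(names):
        raise ValueError('row length must match the number of label names')
    floats = [float(v) for v in row]  # analogous to tf.to_float
    return {name: [value] for name, value in zip(names, floats)}


def queue_capacity(min_queue_examples, batch_size):
    """Same formula as the capacity argument passed to shuffle_batch."""
    return min_queue_examples + 3 * batch_size


def preprocess_threads(num_worker):
    """Same formula as num_preprocess_threads: never fewer than one thread."""
    return max(num_worker - 3, 1)


labels_by_name = split_row(row, label_names)
print(labels_by_name)
print(queue_capacity(1000, 64), preprocess_threads(8))
```

In the real method each dictionary value is a tensor of shape [1], and shuffle_batch adds a leading batch dimension, so every entry of the returned dictionary would have shape [batch_size, 1].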