dataset.py 文件源码

python
阅读 27 收藏 0 点赞 0 评论 0

项目:irelia 作者: jireh-father 项目源码 文件源码
def make_dataset(self, filenames, batch_size, shuffle_buffer_size=100, num_dataset_parallel=4):
        def decode_line(line):
            items = tf.decode_csv(line, [[""], [""], [""]], field_delim=",")
            return items

        if len(filenames) > 1:
            dataset = tf.data.Dataset.from_tensor_slices(filenames)

            dataset = dataset.flat_map(
                lambda filename: (
                    tf.data.TextLineDataset(filename).map(decode_line, num_dataset_parallel)))
        else:
            dataset = tf.data.TextLineDataset(filenames).map(decode_line, num_dataset_parallel)

        if shuffle_buffer_size > 0:
            dataset = dataset.shuffle(shuffle_buffer_size)

        self.dataset_iterator = dataset.batch(batch_size).make_initializable_iterator()
        self.num_samples = Dataset.get_number_of_items(filenames)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号