# Requires: import tensorflow as tf (TensorFlow 1.x API)
def make_dataset(self, filenames, batch_size, shuffle_buffer_size=100, num_dataset_parallel=4):
    # Parse one CSV line into three string columns; the record_defaults
    # [[""], [""], [""]] declare three tf.string fields that default to "".
    def decode_line(line):
        items = tf.decode_csv(line, [[""], [""], [""]], field_delim=",")
        return items

    if len(filenames) > 1:
        # Multiple input files: build a dataset of filenames, then flat_map
        # each file into its parsed lines so they form one flat stream.
        dataset = tf.data.Dataset.from_tensor_slices(filenames)
        dataset = dataset.flat_map(
            lambda filename: tf.data.TextLineDataset(filename).map(
                decode_line, num_parallel_calls=num_dataset_parallel))
    else:
        # Single file: read and parse it directly.
        dataset = tf.data.TextLineDataset(filenames).map(
            decode_line, num_parallel_calls=num_dataset_parallel)

    # Shuffle only when a positive buffer size is given.
    if shuffle_buffer_size > 0:
        dataset = dataset.shuffle(shuffle_buffer_size)

    # TF 1.x initializable iterator: its initializer must be run in a
    # session before any batch can be fetched.
    self.dataset_iterator = dataset.batch(batch_size).make_initializable_iterator()
    self.num_samples = Dataset.get_number_of_items(filenames)
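
For reference, here is a minimal usage sketch under TF 1.x session semantics. The names here are assumptions for illustration: the enclosing class is taken to be the Dataset class referenced above, and data.csv stands for a hypothetical file with three comma-separated string columns. The key point it demonstrates is that an initializable iterator must be initialized inside the session before get_next() can be evaluated.

    # Hypothetical usage; "Dataset" and "data.csv" are illustrative names.
    ds = Dataset()
    ds.make_dataset(["data.csv"], batch_size=32, shuffle_buffer_size=0)
    next_batch = ds.dataset_iterator.get_next()  # tuple of 3 string tensors

    with tf.Session() as sess:
        sess.run(ds.dataset_iterator.initializer)  # must run before get_next()
        col1, col2, col3 = sess.run(next_batch)    # one batch per run() call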