def my_input(file_path, perform_shuffle=False, repeat_count=1,
             batch_size=32, shuffle_buffer_size=256):
    """Create batched feature/label tensors from a CSV file with the Dataset API.

    The file is read line by line and its first line is skipped
    (presumably a header row -- confirm against the data file). Each
    remaining line is parsed as four float columns followed by one integer
    column; the last column is used as the label and the preceding columns
    are paired with the module-level ``feature_names``.

    Args:
        file_path: Path of the CSV file, passed to ``tf.data.TextLineDataset``.
        perform_shuffle: When true, shuffle the examples using a buffer of
            ``shuffle_buffer_size`` elements.
        repeat_count: Passed unchanged to ``Dataset.repeat``.
        batch_size: Number of examples combined into each batch.
        shuffle_buffer_size: Buffer size handed to ``Dataset.shuffle``;
            only used when ``perform_shuffle`` is true.

    Returns:
        A ``(batch_features, batch_labels)`` pair obtained from a one-shot
        iterator over the dataset. ``batch_features`` is a dict keyed by
        the entries of ``feature_names``; ``batch_labels`` holds the last
        CSV column of each example in the batch.
    """
    def decode_csv(line):
        """Parse one CSV line into a ``(features_dict, label)`` pair."""
        # Record defaults fix the column types: four floats, then one int.
        parsed_line = tf.decode_csv(line, [[0.], [0.], [0.], [0.], [0]])
        # Slice rather than delete in place; the label is kept as a
        # one-element list holding the last column.
        features = parsed_line[:-1]
        label = parsed_line[-1:]
        # NOTE: zip() stops at the shorter input, so feature_names (defined
        # elsewhere in this module) should list one name per feature column.
        return dict(zip(feature_names, features)), label

    dataset = tf.data.TextLineDataset(file_path).skip(1).map(decode_csv)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
    dataset = dataset.repeat(repeat_count)
    dataset = dataset.batch(batch_size)
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
评论列表
文章目录