import tensorflow as tf


def data_loader(csv_filename: str, params: Params, batch_size: int=128, data_augmentation: bool=False,
                num_epochs: int=None, image_summaries: bool=False):

    def input_fn():
        # Accept either a single csv file or a list of csv files
        if not isinstance(csv_filename, list):
            filename_queue = tf.train.string_input_producer([csv_filename], num_epochs=num_epochs,
                                                            name='filename_queue')
        else:
            filename_queue = tf.train.string_input_producer(csv_filename, num_epochs=num_epochs,
                                                            name='filename_queue')

        # Read the csv files line by line (there is no header line to skip)
        reader = tf.TextLineReader(name='CSV_Reader', skip_header_lines=0)
        key, value = reader.read(filename_queue, name='file_reading_op')

        # Each line holds an image path and its label, separated by params.csv_delimiter
        default_line = [['None'], ['None']]
        path, label = tf.decode_csv(value, record_defaults=default_line, field_delim=params.csv_delimiter,
                                    name='csv_reading_op')

        # Decode, resize and optionally augment the image
        # (image_reading and Params are defined elsewhere in the same module)
        image, img_width = image_reading(path, resized_size=params.input_shape,
                                         data_augmentation=data_augmentation, padding=True)

        # Shuffle and batch the decoded samples
        to_batch = {'images': image, 'images_widths': img_width, 'filenames': path, 'labels': label}
        prepared_batch = tf.train.shuffle_batch(to_batch,
                                                batch_size=batch_size,
                                                min_after_dequeue=500,
                                                num_threads=15, capacity=4000,
                                                allow_smaller_final_batch=False,
                                                name='prepared_batch_queue')

        if image_summaries:
            tf.summary.image('input/image', prepared_batch.get('images'), max_outputs=1)
            tf.summary.text('input/labels', prepared_batch.get('labels')[:10])
            tf.summary.text('input/widths', tf.as_string(prepared_batch.get('images_widths')))

        return prepared_batch, prepared_batch.get('labels')

    return input_fn
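
data_loader returns an input_fn closure rather than tensors, which is the shape expected by tf.estimator in TensorFlow 1.x. Below is a minimal usage sketch under stated assumptions: a Params instance named params and a model function my_model_fn built elsewhere in the project, plus an illustrative 'train.csv' file; these names are hypothetical and not part of the code above. The Estimator's monitored training session automatically starts the queue runners that tf.train.shuffle_batch depends on.

# params: a Params instance and my_model_fn: a model_fn, both assumed to be
# constructed elsewhere in the project (hypothetical names for illustration).
train_input_fn = data_loader('train.csv', params,
                             batch_size=64,
                             data_augmentation=True,
                             num_epochs=10,
                             image_summaries=True)

# The Estimator calls input_fn inside its own graph/session and starts the
# queue runners needed by the shuffle_batch queue.
estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir='./model_dir')
estimator.train(input_fn=train_input_fn)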