def _produce_one_sample(self):
    """Build the TF1 queue-based pipeline that yields one augmented sample.

    Reads the list file at ``self.path``. Each non-blank line is a filename
    that is looked up under both ``<dirname>/input`` and
    ``<dirname>/output``, where ``dirname`` is the directory containing the
    list file. The paired paths are fed through
    ``tf.train.slice_input_producer``. Both images are decoded with 3
    channels and divided by their white level (255 or 65535). They are
    concatenated along axis 2 and passed to ``self._augment_data``.

    Side effects: sets ``self.nsamples`` to the number of listed pairs.
    When ``self.shuffle`` is true, it also shuffles the Python file list
    before it is handed to the producer.

    Returns:
        dict with keys 'lowres_input', 'lowres_output', 'image_input' and
        'image_output'. Channels 0-2 of each tensor returned by
        ``self._augment_data`` are the input image; channels 3+ are the
        output image.

    Raises:
        ValueError: if the list file's directory fails ``check_dir``, or
            if the list file names no images.
    """
    dirname = os.path.dirname(self.path)
    if not check_dir(dirname):
        raise ValueError("Invalid data path.")
    with open(self.path, 'r') as fid:
        # Iterating the file object works on Python 2 and 3
        # (``xreadlines()`` is Python-2-only). Blank lines are skipped so
        # they do not turn into directory paths below.
        flist = [line.strip() for line in fid if line.strip()]
    if not flist:
        raise ValueError("No images listed in %s." % self.path)
    if self.shuffle:
        random.shuffle(flist)

    input_files = [os.path.join(dirname, 'input', f) for f in flist]
    output_files = [os.path.join(dirname, 'output', f) for f in flist]
    self.nsamples = len(input_files)

    # 0o123 (== 83) is the value the Python-2-only literal ``0123``
    # denoted. The ``0o`` prefix parses on Python 2.6+ and Python 3.
    input_queue, output_queue = tf.train.slice_input_producer(
        [input_files, output_files], shuffle=self.shuffle,
        seed=0o123, num_epochs=self.num_epochs)

    def _decode(filename_tensor, example_path):
        """Read and decode one 3-channel image; return (image, white level).

        Only ``example_path`` (the first file of its list) is probed for
        format and bit depth. Every image in the list is therefore assumed
        to share that file's extension and bit depth.
        """
        contents = tf.read_file(filename_tensor)
        ext = os.path.splitext(example_path)[-1].lower()
        if ext in ('.jpg', '.jpeg'):
            # decode_jpeg always produces uint8, so the white level is 255
            # whatever the magic probe reports.
            return tf.image.decode_jpeg(contents, channels=3), 255.0
        if '16-bit' in magic.from_file(example_path):
            dtype, white_level = tf.uint16, 65535.0
        else:
            dtype, white_level = tf.uint8, 255.0
        image = tf.image.decode_png(contents, dtype=dtype, channels=3)
        return image, white_level

    im_input, input_wl = _decode(input_queue, input_files[0])
    im_output, output_wl = _decode(output_queue, output_files[0])

    sample = {}
    with tf.name_scope('normalize_images'):
        # Scale integer pixel values into [0, 1].
        im_input = tf.to_float(im_input) / input_wl
        im_output = tf.to_float(im_output) / output_wl

    # Stack input and output along the channel axis (3 + 3 channels).
    # Presumably this lets _augment_data apply identical augmentation to
    # both; the meaning of the `6` argument (likely the channel count)
    # should be confirmed against _augment_data.
    inout = tf.concat([im_input, im_output], 2)
    fullres, lowres = self._augment_data(inout, 6)
    sample['lowres_input'] = lowres[:, :, :3]
    sample['lowres_output'] = lowres[:, :, 3:]
    sample['image_input'] = fullres[:, :, :3]
    sample['image_output'] = fullres[:, :, 3:]
    return sample
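# (web-page residue, not code) 评论列表 — "Comment list"
# (web-page residue, not code) 文章目录 — "Table of contents"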