def get_data_kitti(datadir, shuffle_all, batchs):
    """Construct input data lists for Kitti 2012 Evaluation.

    Collects first frames (``*10.png``) and second frames (``*11.png``) from
    ``<datadir>/image_2_crop`` and flow maps (``*.png``) from
    ``<datadir>/flow_occ_crop``. The three sorted lists are paired by
    position and fed through ``tf.train.slice_input_producer``.

    Args:
        datadir: Dataset root directory. A trailing path separator is
            optional.
        shuffle_all: If truthy, one random permutation is applied to all
            three lists (keeping triples aligned) before they are queued;
            otherwise the sorted order is kept.
        batchs: Batch size handed to ``tf.train.batch``.

    Returns:
        The result of ``tf.train.batch`` over ``[imgs_0, imgs_1, flows]``.
        Both images are 3-channel PNGs converted to float32 with
        ``tf.image.convert_image_dtype``; the flow is a 3-channel uint16 PNG
        cast to float32. All three have their static shape set to
        ``FLAGS.img_shape``.

    Raises:
        AssertionError: If the three file lists are empty or differ in
            length.
    """
    # Function-scope import: the module's import block is not visible here.
    import os

    image_dir = os.path.join(datadir, "image_2_crop")
    flow_dir = os.path.join(datadir, "flow_occ_crop")

    # Original author's note: after number 154 image sizes change.
    list_0 = sorted(glob.glob(os.path.join(image_dir, '*10.png')))
    list_1 = sorted(glob.glob(os.path.join(image_dir, '*11.png')))
    flow_list = sorted(glob.glob(os.path.join(flow_dir, '*.png')))
    print(len(list_0), len(list_1), len(flow_list))
    print("Number of input length: " + str(len(list_0)))

    # Raised explicitly instead of via `assert` so the check still runs under
    # `python -O`; the exception type is kept for backward compatibility.
    if not (len(list_0) == len(list_1) == len(flow_list) != 0):
        raise AssertionError('Input Lengths not correct')

    # A single shared index order keeps (frame 0, frame 1, flow) aligned.
    count = len(list_0)
    order = np.random.permutation(count) if shuffle_all else np.arange(count)
    list_0 = [list_0[i] for i in order]
    list_1 = [list_1[i] for i in order]
    flow_list = [flow_list[i] for i in order]

    with tf.name_scope('Input'):
        input_queue = tf.train.slice_input_producer(
            [list_0, list_1, flow_list],
            shuffle=False)  # already shuffled above when requested

        # image reader
        content_0 = tf.read_file(input_queue[0])
        content_1 = tf.read_file(input_queue[1])
        content_flow = tf.read_file(input_queue[2])

        imgs_0 = tf.image.decode_png(content_0, channels=3)
        imgs_1 = tf.image.decode_png(content_1, channels=3)
        imgs_0 = tf.image.convert_image_dtype(imgs_0, dtype=tf.float32)
        imgs_1 = tf.image.convert_image_dtype(imgs_1, dtype=tf.float32)

        # NOTE(review): the uint16 PNG values are only cast to float here,
        # not rescaled -- presumably decoded into flow vectors downstream;
        # confirm against the caller.
        flows = tf.cast(
            tf.image.decode_png(content_flow, channels=3, dtype=tf.uint16),
            tf.float32)

        # set shape
        imgs_0.set_shape(FLAGS.img_shape)
        imgs_1.set_shape(FLAGS.img_shape)
        flows.set_shape(FLAGS.img_shape)

        return tf.train.batch([imgs_0, imgs_1, flows],
                              batch_size=batchs)
# Comment list
# Table of contents