flownet_tools.py source code

python

Project: Bayesian-FlowNet  Author: Johswald
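import numpy as np
import tensorflow as tf

# FLAGS (record_bytes, flow_shape, img_shape) is expected to be defined at
# module level in the original file; its definition is not shown here.


def tensorflow_reader(list_0, list_1, flow_list, shuffle_all, batchs):
    """Build a batched TensorFlow input pipeline for image pairs and flow.

    Keyword arguments:
    list_0 -- list of file paths of the first image of each pair
    list_1 -- list of file paths of the second image of each pair
    flow_list -- list of .flo file paths with the optical flow between them
    shuffle_all -- if True, shuffle all three lists with the same permutation
    batchs -- batch size
    """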

    assert len(list_0) == len(list_1) == len(flow_list) != 0, (
        'Input lists must be non-empty and of equal length')

    print("Number of inputs: " + str(len(list_0)))
    if shuffle_all:
        p = np.random.permutation(len(list_0))
    else:
        p = np.arange(len(list_0))
    list_0 = [list_0[i] for i in p]
    list_1 = [list_1[i] for i in p]
    flow_list = [flow_list[i] for i in p]
    input_queue = tf.train.slice_input_producer(
        [list_0, list_1],
        shuffle=False)  # shuffled before
    # image reader
    content_0 = tf.read_file(input_queue[0])
    content_1 = tf.read_file(input_queue[1])
    imgs_0 = tf.image.decode_image(content_0, channels=3)
    imgs_1 = tf.image.decode_image(content_1, channels=3)
    # convert to [0, 1] images
    imgs_0 = tf.image.convert_image_dtype(imgs_0, dtype=tf.float32)
    imgs_1 = tf.image.convert_image_dtype(imgs_1, dtype=tf.float32)
    # flow reader
    filename_queue = tf.train.string_input_producer(flow_list, shuffle=False)
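    # Each .flo file is read as one fixed-length record:
    # 12-byte header + H * W * 2 float32 values
    # (e.g. 12 + 384 * 512 * 2 * 4 = 1572876 bytes).
    reader = tf.FixedLengthRecordReader(record_bytes=FLAGS.record_bytes)
    _, value = reader.read(filename_queue)
    record = tf.decode_raw(value, tf.float32)

    # .flo layout: magic number (float32, 202021.25), then width and height,
    # then the interleaved (u, v) flow values. Width and height are stored as
    # int32, so decoding them as float32 here does not yield usable numbers;
    # neither header field is used below.
    magic = tf.slice(record, [0], [1])
    size = tf.slice(record, [1], [2])
    flows = tf.slice(record, [3], [np.prod(FLAGS.flow_shape)])
    flows = tf.reshape(flows, FLAGS.flow_shape)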

    # set shape
    imgs_0.set_shape(FLAGS.img_shape)
    imgs_1.set_shape(FLAGS.img_shape)
    flows.set_shape(FLAGS.flow_shape)

    return tf.train.batch([imgs_0, imgs_1, flows],
                          batch_size=batchs
                          #,num_threads=1
                          )
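
The flow-reading part of the function above depends on the `.flo` file layout: a float32 magic number (202021.25), then width and height as int32, then interleaved (u, v) float32 values. The following is a minimal NumPy sketch of that layout, not part of the original file. The helper names `flo_record_bytes`, `write_flo` and `read_flo` are hypothetical, and little-endian byte order is assumed.

```python
import struct

import numpy as np

FLO_MAGIC = 202021.25  # float32 tag at the start of a .flo file


def flo_record_bytes(height, width):
    """Total file size in bytes: 12-byte header + 2 float32 values per pixel."""
    return 12 + height * width * 2 * 4


def write_flo(path, flow):
    """Write an (H, W, 2) flow array to `path` in .flo layout."""
    height, width = flow.shape[:2]
    with open(path, "wb") as f:
        f.write(struct.pack("<f", FLO_MAGIC))       # magic number
        f.write(struct.pack("<ii", width, height))  # width first, then height
        f.write(np.asarray(flow, dtype="<f4").tobytes())


def read_flo(path):
    """Read a .flo file and return the flow as an (H, W, 2) float32 array."""
    with open(path, "rb") as f:
        magic = struct.unpack("<f", f.read(4))[0]
        if magic != FLO_MAGIC:
            raise ValueError("not a .flo file: bad magic number %r" % magic)
        width, height = struct.unpack("<ii", f.read(8))
        data = np.frombuffer(f.read(), dtype="<f4")
    return data.reshape(height, width, 2)
```

For a 384 x 512 flow field, `flo_record_bytes(384, 512)` gives 1572876, which matches the value noted next to `FLAGS.record_bytes`. Because the TensorFlow reader uses one fixed record length, it presumably expects every flow file in `flow_list` to have the same dimensions.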