def get_data(img_folder, label_folder, train_fraction, img_size,
             train_timesteps=4, test_timesteps=4, batch_size=1, sample_objects=False, n_threads=3,
             in_memory=False, which_seqs=None, truncated_threshold=2., occluded_threshold=3., depth_folder=None,
             storage_dtype=tf.uint8, mirror=False, reverse=False, bbox_scale=.5):
    """Build KITTI tracking train/test stores and fetch one minibatch from each.

    The labels are parsed once with ``KittiTrackingParser``. The parsed
    ``data_dict`` is split into train and test parts with
    ``split_sequence_dict(..., train_fraction)``, and each part is wrapped in
    a ``KittiStore``.

    Args:
        img_folder: Image directory, forwarded to ``KittiTrackingParser``.
        label_folder: Label directory, forwarded to ``KittiTrackingParser``.
        train_fraction: Forwarded to ``split_sequence_dict`` to decide the
            train/test split.
        img_size: Forwarded to both ``KittiStore`` instances.
        train_timesteps: Timesteps for the train store.
        test_timesteps: Timesteps for the test store.
        batch_size: Forwarded to both stores.
        sample_objects, in_memory, which_seqs, depth_folder, storage_dtype,
        bbox_scale: Forwarded unchanged to both stores.
        n_threads: Thread count for the train store. The test store uses
            ``n_threads // 2 + 1``.
        truncated_threshold, occluded_threshold: Forwarded to
            ``KittiTrackingParser``.
        mirror, reverse: Applied to the train store only. The test store
            always receives ``False`` for both.

    Returns:
        A 4-tuple ``(train_store, train_minibatch, test_store, test_minibatch)``.
        Each minibatch is whatever the store's ``get_minibatch()`` returns.
    """
    parser = KittiTrackingParser(
        img_folder, label_folder,
        presence=True, id=False, cls=False,
        truncated_threshold=truncated_threshold,
        occluded_threshold=occluded_threshold,
    )
    train_seqs, test_seqs = split_sequence_dict(parser.data_dict, train_fraction)

    # Store options that are identical for the train and test splits.
    shared_options = dict(
        sample_objects=sample_objects,
        which_seqs=which_seqs,
        in_memory=in_memory,
        depth_folder=depth_folder,
        storage_dtype=storage_dtype,
        bbox_scale=bbox_scale,
    )

    train_store = KittiStore(
        train_seqs, train_timesteps, img_size, batch_size,
        n_threads=n_threads, mirror=mirror, reverse=reverse,
        name='train', **shared_options
    )
    # Augmentation flags are never applied to the test split, and it gets
    # roughly half as many threads as the train split.
    test_store = KittiStore(
        test_seqs, test_timesteps, img_size, batch_size,
        n_threads=n_threads // 2 + 1, mirror=False, reverse=False,
        name='test', **shared_options
    )

    # Request the train minibatch before the test one, matching the
    # original call order.
    train_batch = train_store.get_minibatch()
    test_batch = test_store.get_minibatch()
    return train_store, train_batch, test_store, test_batch
# 评论列表 (stray web-page residue: "comment list")
# 文章目录 (stray web-page residue: "table of contents")