def extract_batch(dataset, config):
    """Build the shuffled training-batch tensors for ``dataset``.

    The pipeline reads single examples through a TF-Slim
    ``DatasetDataProvider``, normalizes and augments them, encodes the
    ground-truth boxes with a ``PriorBoxGrid`` and finally groups the
    results with ``tf.train.shuffle_batch``.

    Besides its parameters, the function reads ``args.segment`` and
    ``args.batch_size`` from a module-level ``args`` object (presumably the
    parsed command-line options -- confirm against the rest of the file).

    Args:
        dataset: Dataset object handed to ``DatasetDataProvider``; it must
            expose the items ``'image'``, ``'object/bbox'``,
            ``'object/label'`` and, when ``args.segment`` is truthy,
            ``'image/segmentation'``.
        config: Configuration forwarded unchanged to ``PriorBoxGrid`` and
            ``data_augmentation``.

    Returns:
        The result of ``tf.train.shuffle_batch`` over
        ``[im, inds, refine, cats, seg]``, in that order. Note that this
        order differs from the ``(inds, cats, refine)`` order in which
        ``encode_gt_tf`` returns its values.
    """
    # Every op created below (reading, augmentation, encoding, batching) is
    # placed on the CPU device.
    with tf.device("/cpu:0"):
        bboxer = PriorBoxGrid(config)
        data_provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset, num_readers=2,
            common_queue_capacity=512, common_queue_min=32)

        if args.segment:
            im, bbox, gt, seg = data_provider.get(
                ['image', 'object/bbox', 'object/label',
                 'image/segmentation'])
        else:
            im, bbox, gt = data_provider.get(
                ['image', 'object/bbox', 'object/label'])
            # No segmentation requested: substitute an all-zero map built
            # from the first two dimensions of the image plus one trailing
            # axis, so the rest of the pipeline always receives a `seg`.
            seg = tf.expand_dims(tf.zeros(tf.shape(im)[:2]), 2)

        # Cast to float and divide by 255 (assumes 8-bit pixel values --
        # TODO confirm against the dataset decoder).
        im = tf.to_float(im)/255
        # Clip box coordinates to [0, 1] before converting their layout.
        bbox = yxyx_to_xywh(tf.clip_by_value(bbox, 0.0, 1.0))
        im, bbox, gt, seg = data_augmentation(im, bbox, gt, seg, config)
        inds, cats, refine = bboxer.encode_gt_tf(bbox, gt)

        return tf.train.shuffle_batch(
            [im, inds, refine, cats, seg],
            args.batch_size,
            capacity=2048,
            min_after_dequeue=64,
            num_threads=4)
评论列表
文章目录