training.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:blitznet 作者: dvornikita 项目源码 文件源码
def extract_batch(dataset, config):
    """Build the input pipeline on the CPU and return a shuffled batch.

    Reads samples from ``dataset`` through a slim ``DatasetDataProvider``,
    rescales the image by 1/255 after converting it to float, clips the
    boxes to [0, 1] and passes them through ``yxyx_to_xywh``, applies
    ``data_augmentation`` and encodes the ground truth with
    ``PriorBoxGrid.encode_gt_tf``.

    Note: besides its parameters, this function reads the module-level
    ``args`` object (``args.segment`` and ``args.batch_size``).

    Args:
        dataset: dataset object handed to ``DatasetDataProvider``.
        config: configuration forwarded to ``PriorBoxGrid`` and
            ``data_augmentation``.

    Returns:
        The result of ``tf.train.shuffle_batch`` over
        ``[im, inds, refine, cats, seg]``.
    """
    with tf.device("/cpu:0"):
        bboxer = PriorBoxGrid(config)
        data_provider = slim.dataset_data_provider.DatasetDataProvider(
            dataset,
            num_readers=2,
            common_queue_capacity=512,
            common_queue_min=32)

        # Features requested in both modes; segmentation is added on demand.
        detection_keys = ['image', 'object/bbox', 'object/label']
        if args.segment:
            im, bbox, gt, seg = data_provider.get(
                detection_keys + ['image/segmentation'])
        else:
            im, bbox, gt = data_provider.get(detection_keys)
            # No segmentation requested: substitute an all-zero single-channel
            # map sized from the first two dimensions of the image.
            blank = tf.zeros(tf.shape(im)[:2])
            seg = tf.expand_dims(blank, 2)

        im = tf.to_float(im) / 255
        clipped_bbox = tf.clip_by_value(bbox, 0.0, 1.0)
        bbox = yxyx_to_xywh(clipped_bbox)

        im, bbox, gt, seg = data_augmentation(im, bbox, gt, seg, config)
        inds, cats, refine = bboxer.encode_gt_tf(bbox, gt)

        # Tensor order here defines the order of the returned batch.
        sample = [im, inds, refine, cats, seg]
        return tf.train.shuffle_batch(sample,
                                      batch_size=args.batch_size,
                                      capacity=2048,
                                      min_after_dequeue=64,
                                      num_threads=4)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号