utils_tf.py source code

python

Project: blitznet  Author: dvornikita
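```python
import tensorflow as tf  # TensorFlow 1.x API

# apply_with_random_selector, photometric_distortions, scale_distortions,
# mirror_distortions and filter_small_gt are defined elsewhere in utils_tf.py.


def data_augmentation(img, gt_bboxes, gt_cats, seg, config):
    params = config['train_augmentation']
    # Photometric (color) distortions: apply_with_random_selector samples
    # `ordering` from 0..3 and passes it to photometric_distortions.
    img = apply_with_random_selector(
        img,
        lambda x, ordering: photometric_distortions(x, ordering, params),
        num_cases=4)

    # Append the segmentation mask as an extra channel so the geometric
    # distortions below transform the image and the mask identically.
    if seg is not None:
        img = tf.concat([img, tf.cast(seg, tf.float32)], axis=-1)

    img, gt_bboxes, gt_cats = scale_distortions(img, gt_bboxes, gt_cats,
                                                params)
    img, gt_bboxes = mirror_distortions(img, gt_bboxes, params)
    # XXX reference implementation also randomizes interpolation method
    img_size = config['image_size']
    # Resize only the RGB channels to the network input size.
    img_out = tf.image.resize_images(img[..., :3], [img_size, img_size])
    # Drop ground-truth boxes below a minimum size of 2 / img_size
    # (about 2 pixels at the output resolution if boxes are normalized).
    gt_bboxes, gt_cats = filter_small_gt(gt_bboxes, gt_cats, 2 / img_size)

    if seg is not None:
        # Recover the transformed mask from channel 3 and resize it to the
        # first feature-map size; nearest-neighbour interpolation avoids
        # blending class labels.
        seg_shape = config['fm_sizes'][0]
        seg = tf.expand_dims(tf.expand_dims(img[..., 3], 0), -1)
        seg = tf.squeeze(tf.image.resize_nearest_neighbor(seg, [seg_shape, seg_shape]))
        # Round back to integer class IDs.
        seg = tf.cast(tf.round(seg), tf.int64)
    return img_out, gt_bboxes, gt_cats, seg
```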
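The key trick in `data_augmentation` is that the segmentation mask rides along as a fourth channel, so every geometric transform moves pixels and labels together; afterwards the mask is split off and rounded, and tiny boxes are filtered out. The NumPy sketch below illustrates that idea for a horizontal flip only. It is not blitznet's code: `mirror_with_mask` and `filter_small_boxes` are hypothetical helpers, and the sketch assumes boxes are normalized `[xmin, ymin, xmax, ymax]` and that a box is kept only when both its width and height reach `min_size`. The source does not show the box format or the exact criteria used by `mirror_distortions` and `filter_small_gt`.

```python
import numpy as np


def mirror_with_mask(img, mask, boxes, flip):
    """Hypothetical helper: flip image, mask and boxes together.

    img:   H x W x 3 float array
    mask:  H x W integer class labels
    boxes: N x 4 float array, assumed normalized [xmin, ymin, xmax, ymax]
    """
    # Stack the mask as a 4th channel so one geometric op moves both.
    stacked = np.concatenate([img, mask[..., None].astype(img.dtype)], axis=-1)
    if flip:
        stacked = stacked[:, ::-1, :]  # reverse the width axis
        boxes = boxes.copy()
        # x -> 1 - x, so the new xmin comes from the old xmax and vice versa.
        boxes[:, [0, 2]] = 1.0 - boxes[:, [2, 0]]
    # Split the channels apart again; rounding recovers integer labels.
    img_out = stacked[..., :3]
    mask_out = np.round(stacked[..., 3]).astype(np.int64)
    return img_out, mask_out, boxes


def filter_small_boxes(boxes, cats, min_size):
    """Hypothetical helper: keep boxes whose width and height reach min_size."""
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    keep = (widths >= min_size) & (heights >= min_size)
    return boxes[keep], cats[keep]
```

Passing `flip` explicitly keeps the helper deterministic; in a training loop you would draw it at random, for example `flip = rng.random() < 0.5` with `rng = np.random.default_rng()` (the 0.5 probability is itself an assumption). For a 4-pixel-wide output, `filter_small_boxes(boxes, cats, 2 / 4)` mirrors the `2 / img_size` threshold used above.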