utils.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:uct_atari 作者: 5vision 项目源码 文件源码
def augment_image(image, max_brightness=0.1, max_rotation=1, max_scale=0.01,
                  max_jitter=1, max_shift=1):
    """Apply one random brightness/rotation/scale/affine jitter to an image stack.

    A single random transform is drawn per call and applied identically to
    every channel, so all channels of ``image`` stay aligned with each other.

    Args:
        image: channels-first array ``(ch, h, w)``; the first axis is moved to
            the end before warping and moved back afterwards.
            NOTE(review): presumably a stack of frames -- confirm with callers.
        max_brightness: a constant offset drawn uniformly from
            ``[-max_brightness, max_brightness]`` is added to every pixel.
        max_rotation: rotation angle in degrees, drawn uniformly from
            ``[-max_rotation, max_rotation]``, about the image center.
        max_scale: isotropic zoom factor drawn uniformly from
            ``[1 - max_scale, 1 + max_scale]``.
        max_jitter: per-control-point integer jitter (pixels) for the affine
            warp, drawn from ``[-max_jitter, max_jitter]``.
        max_shift: global integer translation (pixels) applied to all affine
            control points, drawn from ``[-max_shift, max_shift]`` per axis.

    Returns:
        The augmented array in the same ``(ch, h, w)`` layout. Because a float
        offset is added with ``np.add``, integer input comes back as a
        floating-point array.
    """
    # move channel to the last axis: (ch, h, w) -> (h, w, ch)
    image = np.rollaxis(image, 0, 3)
    h, w, ch = image.shape[:3]

    # brightness
    brightness = random.uniform(-max_brightness, max_brightness)

    # rotation and scaling; OpenCV points are (x, y) == (column, row), so the
    # center is (w/2, h/2). Float division keeps the exact center on Python 2.
    center = (w / 2.0, h / 2.0)
    angle = random.uniform(-max_rotation, max_rotation)
    zoom = random.uniform(1.0 - max_scale, 1.0 + max_scale)
    Mrot = cv2.getRotationMatrix2D(center, angle, zoom)

    # affine transform: jitter three corners (top-left, top-right,
    # bottom-right) independently, plus one shared shift.
    pts1 = np.float32([[0, 0], [w, 0], [w, h]])
    shiftx = random.randint(-max_shift, max_shift)
    shifty = random.randint(-max_shift, max_shift)
    pts2 = np.float32([
        [x + random.randint(-max_jitter, max_jitter) + shiftx,
         y + random.randint(-max_jitter, max_jitter) + shifty]
        for x, y in ((0, 0), (w, 0), (w, h))
    ])
    M = cv2.getAffineTransform(pts1, pts2)

    def _augment(chunk):
        """Apply brightness, rotation/scale, then the affine jitter to one chunk."""
        chunk = np.add(chunk, brightness)
        warped = cv2.warpAffine(cv2.warpAffine(chunk, Mrot, (w, h)), M, (w, h))
        # warpAffine returns a 2-D array for single-channel input; restore the
        # channel axis so the chunks can be concatenated.
        if warped.ndim < 3:
            warped = np.expand_dims(warped, 2)
        return warped

    # Apply the same transform to every channel, warping at most four
    # channels at a time (presumably to stay within OpenCV's channel limit).
    # A list (not a lazy map) is required by np.concatenate on Python 3.
    augmented = np.concatenate(
        [_augment(image[..., i:i + 4]) for i in range(0, ch, 4)], axis=-1)

    # roll channel axis back when returning: (h, w, ch) -> (ch, h, w)
    return np.rollaxis(augmented, 2, 0)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号