train.py source code

python

Project: chainer-spatial-transformer-networks    Author: hvy
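import numpy as np
from skimage import transform

# img_size: (height, width) of the output canvas. It is a module-level
# global in train.py and is not shown in this excerpt.


def transform_mnist_rts(in_data):
    """Randomly rotate, scale and translate one (image, label) MNIST example."""
    img, label = in_data
    img = img[0]  # Remove the channel axis for skimage manipulation

    # Rotate by a random angle in [-45, 45) degrees; resize=True enlarges the
    # output so that the whole rotated digit fits instead of being clipped
    img = transform.rotate(img, angle=np.random.uniform(-45, 45),
                           resize=True, mode='constant')
    # Scale by a random factor in [0.7, 1.2)
    img = transform.rescale(img, scale=np.random.uniform(0.7, 1.2),
                            mode='constant')

    # Translate
    h, w = img.shape
    if h >= img_size[0] or w >= img_size[1]:
        # Too large for the canvas: resize to the canvas size instead
        img = transform.resize(img, output_shape=img_size, mode='constant')
        img = img.astype(np.float32)
    else:
        # Paste the digit at a random offset on a blank canvas. randint's
        # upper bound is exclusive, so + 1 lets the digit reach the
        # bottom and right edges of the canvas
        img_canvas = np.zeros(img_size, dtype=np.float32)
        ymin = np.random.randint(0, img_size[0] - h + 1)
        xmin = np.random.randint(0, img_size[1] - w + 1)
        img_canvas[ymin:ymin+h, xmin:xmin+w] = img
        img = img_canvas

    img = img[np.newaxis, :]  # Add the channel axis back
    return img, label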
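The random translation is the least obvious step of `transform_mnist_rts`: the rotated and scaled digit is pasted at a uniformly random offset on an empty canvas. The NumPy-only sketch below isolates that step so it can be checked without skimage or Chainer. The helper name `place_on_canvas`, the seeded `np.random.default_rng` generator and the 42×42 canvas are illustrative assumptions, not values taken from train.py. Like `np.random.randint`, `Generator.integers` excludes its upper bound, so the offset range needs `+ 1` for the digit to be able to touch the bottom and right edges.

```python
import numpy as np


def place_on_canvas(img, canvas_shape, rng):
    """Paste a 2-D image at a uniformly random offset on a zero canvas.

    Hypothetical helper mirroring the "Translate" branch of
    transform_mnist_rts; like that branch, it only handles images that are
    strictly smaller than the canvas in both dimensions.
    """
    h, w = img.shape
    ch, cw = canvas_shape
    if h >= ch or w >= cw:
        raise ValueError("image must be smaller than the canvas")
    canvas = np.zeros(canvas_shape, dtype=np.float32)
    # integers() excludes the upper bound, so + 1 makes every offset that
    # keeps the image fully inside the canvas reachable, including the last
    ymin = rng.integers(0, ch - h + 1)
    xmin = rng.integers(0, cw - w + 1)
    canvas[ymin:ymin + h, xmin:xmin + w] = img
    return canvas, (int(ymin), int(xmin))


rng = np.random.default_rng(0)
digit = np.ones((20, 20), dtype=np.float32)  # stand-in for a scaled digit
canvas, (y, x) = place_on_canvas(digit, (42, 42), rng)  # 42x42 is illustrative
print(canvas.shape, canvas.sum(), (y, x))
```

Because `transform_mnist_rts` takes an `(img, label)` tuple and returns one, it has the form of per-example transform that `chainer.datasets.TransformDataset` applies; the sketch covers only the translation, not the rotation and scaling.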