label_type.py 文件源码

python
阅读 31 收藏 0 点赞 0 评论 0

项目:sail 作者: GemHunt 项目源码 文件源码
def create_design_data_set(labeled_designs,design_crop_dir,image_dir,test):
    """Write randomly rotated 28x28 crops of labeled coin designs, then exit.

    For every ``coin_id`` in ``labeled_designs`` the 56 source crops named
    ``<coin_id zero-padded to 5><image_id zero-padded to 2>.png`` are read
    from ``design_crop_dir``. Each crop is resized to 56x56, rotated by a
    random angle around a (possibly jittered) center, shrunk to 41x41 and
    center-cropped to 28x28. The results are written to
    ``image_dir + label + '/'`` with a two-digit rotation counter appended
    to the file stem.

    Args:
        labeled_designs: Mapping of coin id to its label; the labels are
            used as output sub-directory names ('heads' / 'tails'
            directories are created up front).
        design_crop_dir: Directory prefix of the source crops. It is
            concatenated directly with the filename, so it must end with a
            path separator.
        image_dir: Directory prefix for the output. It is concatenated
            directly with the label, so it must end with a path separator.
        test: When truthy, one rotation per crop and no center jitter;
            otherwise 100 rotations per crop with up to +/-2 px of jitter.

    Raises:
        IOError: If a source crop cannot be read.
        SystemExit: Always, once all images have been written.
    """
    labels = ['heads', 'tails']
    if test:
        pixels_to_jitter = 0
        angles = 1
    else:
        pixels_to_jitter = 2
        angles = 100

    for label in labels:
        label_dir = image_dir + label + '/'
        if not os.path.exists(label_dir):
            os.makedirs(label_dir)

    before_rotate_size = 56
    rotate_dims = (before_rotate_size, before_rotate_size)
    half_size = before_rotate_size / 2.0

    # .items() works on both Python 2 and Python 3 dicts.
    for coin_id, label in labeled_designs.items():
        class_dir = image_dir + label + '/'
        for image_id in range(56):
            filename = str(coin_id).zfill(5) + str(image_id).zfill(2) + '.png'
            source_path = design_crop_dir + filename
            image = cv2.imread(source_path)
            if image is None:
                # cv2.imread signals failure by returning None instead of
                # raising; fail here with the offending path rather than
                # with an opaque assertion inside cv2.resize.
                raise IOError('Could not read design crop: ' + source_path)
            image = cv2.resize(image, rotate_dims, interpolation=cv2.INTER_AREA)
            for count in range(angles):
                # The order of the random draws (angle, x, y) is kept so a
                # seeded run reproduces the same augmentations.
                angle = random.random() * 360
                center_x = half_size + (random.random() * pixels_to_jitter * 2) - pixels_to_jitter
                center_y = half_size + (random.random() * pixels_to_jitter * 2) - pixels_to_jitter
                m = cv2.getRotationMatrix2D((center_x, center_y), angle, 1)
                # Warp into a fresh array: warpAffine is documented as not
                # supporting in-place operation.
                rot_image = cv2.warpAffine(image, m, rotate_dims, flags=cv2.INTER_CUBIC)
                # This is hard coded for 28x28: shrink to 41x41 and keep the
                # central 28x28 window (presumably to drop the corners that
                # the rotation leaves empty -- confirm).
                rot_image = cv2.resize(rot_image, (41, 41), interpolation=cv2.INTER_AREA)
                rot_image = rot_image[6:34, 6:34]
                rotated_filename = filename.replace('.png', str(count).zfill(2) + '.png')
                cv2.imwrite(class_dir + rotated_filename, rot_image)
    # NOTE(review): terminates the whole process once the data set is built;
    # kept because callers may depend on nothing running after this call.
    sys.exit()
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号