utils.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:bird_classification 作者: halwai 项目源码 文件源码
def _load_crop(path, bb, height, width, set_type):
    """Read one image, crop it to its bounding box and preprocess it.

    ``bb`` is indexed as (column offset, row offset, width, height): bb[1]/bb[3]
    slice the row axis and bb[0]/bb[2] slice the column axis.
    """
    # NOTE(review): scipy.misc.imread was removed in SciPy 1.2; on newer
    # SciPy this needs e.g. imageio.imread instead.
    # gray2rgb is defined elsewhere in this file; it presumably expands
    # grayscale images to 3 channels (the crop below needs a 3-D array).
    img = gray2rgb(misc.imread(path), path)
    # Cast to int so float box coordinates can be used as slice bounds.
    x, y, w, h = int(bb[0]), int(bb[1]), int(bb[2]), int(bb[3])
    img = img[y:y + h, x:x + w, :]
    return preprocess_image(img, height, width, set_type)


def get_batch(generator_type, set_type, height, width):
    """Load, crop and preprocess the first batch yielded by ``generator_type``.

    Args:
        generator_type: iterable of batches. For ``set_type`` 'train' or 'val'
            each batch is ``(paths, bbs, labels)``; otherwise ``(paths, bbs)``.
            Only the first batch is consumed.
        set_type: dataset split name; also forwarded to ``preprocess_image``.
        height: target height passed to ``preprocess_image``.
        width: target width passed to ``preprocess_image``.

    Returns:
        ``(imgs, labels)``, where ``imgs`` is ``np.asarray`` of the preprocessed
        crops, in the same order as ``paths``. ``labels`` is the batch's labels
        for 'train'/'val' and ``None`` otherwise. If the iterable yields no
        batch at all, returns an empty array and ``None``.
    """
    has_labels = set_type in ('train', 'val')

    try:
        batch = next(iter(generator_type))
    except StopIteration:
        # Previously this crashed (UnboundLocalError) for train/val and
        # returned a plain list for other splits.
        return np.asarray([]), None

    if has_labels:
        paths, bbs, labels = batch
    else:
        paths, bbs = batch
        labels = None

    imgs = np.asarray([
        _load_crop(path, bb, height, width, set_type)
        for path, bb in zip(paths, bbs)
    ])
    return imgs, labels



#store in required csv format
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号