kerasext.py source code


Project: kaggle_amazon  Author: asanakoy
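```python
from os.path import join

import cv2
import numpy as np
from tqdm import tqdm


def predict(model_name, model, images_dir, image_ids, batch_size=64, tile_size=224):
    # Preallocate an (N, H, W, 3) float32 batch; images that fail to load stay all-zero.
    x_test = np.zeros((len(image_ids), tile_size, tile_size, 3), dtype=np.float32)

    for idx, image_name in tqdm(enumerate(image_ids), total=len(image_ids)):
        image_path = join(images_dir, '{}.jpg'.format(image_name))
        try:
            img = cv2.imread(image_path)  # BGR uint8 array, or None if the file is unreadable
            if img is None:
                raise IOError('cv2.imread could not read the file')
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # Keras preprocess_input expects RGB
            img = np.asarray(cv2.resize(img, (tile_size, tile_size)), dtype=np.float32)
            x_test[idx, ...] = img
        except Exception as e:
            print(e)
            print('image:', image_path)
    # get_preprocess_input_fn is defined elsewhere in the project (not shown in this snippet);
    # it returns the preprocessing function that matches model_name.
    x_test = get_preprocess_input_fn(model_name)(x_test)
    print(x_test.shape)
    predictions = model.predict(x_test, batch_size=batch_size, verbose=1)
    return predictions
```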