def _load_tile(image_path, tile_size):
    """Read one image file and return it as an RGB array resized to tile_size x tile_size.

    Raises IOError when OpenCV cannot read the file: cv2.imread signals a
    missing or undecodable file by returning None instead of raising.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise IOError('cv2.imread could not read {}'.format(image_path))
    # OpenCV decodes images in BGR channel order; convert to RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return cv2.resize(img, (tile_size, tile_size))


def predict(model_name, model, images_dir, image_ids, batch_size=64, tile_size=224):
    """Load '<images_dir>/<id>.jpg' for every id, preprocess the batch and run model.predict.

    Args:
        model_name: key passed to get_preprocess_input_fn to select the
            preprocessing function applied to the whole batch.
        model: object exposing predict(x, batch_size=..., verbose=...)
            (presumably a Keras model -- confirm against callers).
        images_dir: directory containing the '.jpg' files.
        image_ids: sized sequence of image ids (file names without extension);
            output rows follow this order.
        batch_size: batch size forwarded to model.predict.
        tile_size: every image is resized to tile_size x tile_size pixels.

    Returns:
        Whatever model.predict returns for the stacked batch, with one entry per id.

    Images that cannot be loaded are reported on stdout and left as an
    all-zero row before preprocessing. A prediction is still produced for
    them, so the output stays aligned with image_ids.
    """
    # The whole batch is held in memory at once: len(image_ids) * tile_size**2 * 3 float32s.
    x_test = np.zeros((len(image_ids), tile_size, tile_size, 3), dtype=np.float32)
    for idx, image_name in tqdm(enumerate(image_ids), total=len(image_ids)):
        image_path = join(images_dir, '{}.jpg'.format(image_name))
        try:
            # Assigning the uint8 tile into the float32 buffer casts it.
            x_test[idx, ...] = _load_tile(image_path, tile_size)
        except Exception as e:
            # Best-effort: report the failure and keep going. str(e) works on
            # Python 2 and 3, unlike the deprecated/removed e.message.
            print(str(e))
            print('image: {}'.format(image_path))
    x_test = get_preprocess_input_fn(model_name)(x_test)
    print(x_test.shape)
    return model.predict(x_test, batch_size=batch_size, verbose=1)
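# Web-page residue from the page this snippet was copied from:
# "评论列表" ("Comment list"), "文章目录" ("Table of contents").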