def inference(model_name, weight_file, image_size, image_list, data_dir, label_dir, return_results=True, save_dir=None,
label_suffix='.png',
data_suffix='.jpg'):
current_dir = os.path.dirname(os.path.realpath(__file__))
# mean_value = np.array([104.00699, 116.66877, 122.67892])
batch_shape = (1, ) + image_size + (3, )
save_path = os.path.join(current_dir, 'Models/'+model_name)
model_path = os.path.join(save_path, "model.json")
checkpoint_path = os.path.join(save_path, weight_file)
# model_path = os.path.join(current_dir, 'model_weights/fcn_atrous/model_change.hdf5')
# model = FCN_Resnet50_32s((480,480,3))
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
session = tf.Session(config=config)
K.set_session(session)
model = globals()[model_name](batch_shape=batch_shape, input_shape=(512, 512, 3))
model.load_weights(checkpoint_path, by_name=True)
model.summary()
results = []
total = 0
for img_num in image_list:
img_num = img_num.strip('\n')
total += 1
print('#%d: %s' % (total,img_num))
image = Image.open('%s/%s%s' % (data_dir, img_num, data_suffix))
image = img_to_array(image) # , data_format='default')
label = Image.open('%s/%s%s' % (label_dir, img_num, label_suffix))
label_size = label.size
img_h, img_w = image.shape[0:2]
# long_side = max(img_h, img_w, image_size[0], image_size[1])
pad_w = max(image_size[1] - img_w, 0)
pad_h = max(image_size[0] - img_h, 0)
image = np.lib.pad(image, ((pad_h/2, pad_h - pad_h/2), (pad_w/2, pad_w - pad_w/2), (0, 0)), 'constant', constant_values=0.)
# image -= mean_value
'''img = array_to_img(image, 'channels_last', scale=False)
img.show()
exit()'''
# image = cv2.resize(image, image_size)
image = np.expand_dims(image, axis=0)
image = preprocess_input(image)
result = model.predict(image, batch_size=1)
result = np.argmax(np.squeeze(result), axis=-1).astype(np.uint8)
result_img = Image.fromarray(result, mode='P')
result_img.palette = label.palette
# result_img = result_img.resize(label_size, resample=Image.BILINEAR)
result_img = result_img.crop((pad_w/2, pad_h/2, pad_w/2+img_w, pad_h/2+img_h))
# result_img.show(title='result')
if return_results:
results.append(result_img)
if save_dir:
result_img.save(os.path.join(save_dir, img_num + '.png'))
return results
评论列表
文章目录