def classify_image(image_paths=None,
                   model_path=os.path.join('model', 'model.mod'),
                   cutoff_file='cutoffs.npy'):
    """Classify images into genres with a trained model and print the results.

    Reads the genre names from ``training_data/genres.txt`` and the
    preprocessing method from ``training_data/preprocess.txt``, loads the
    model with ``load_model``, turns every image into model input with
    ``process_screen`` and, for each image, prints the raw prediction vector
    followed by the name of every genre whose score is at least that genre's
    cutoff from ``cutoffs/<cutoff_file>``.

    Args:
        image_paths: Paths of the images to classify. Defaults to
            ``['img.jpg']``.
        model_path: Path handed to ``load_model``.
        cutoff_file: File name, inside the ``cutoffs`` directory, of a NumPy
            array indexed by output class; entry ``i`` is the threshold for
            class ``i``.

    Raises:
        ValueError: If ``preprocess.txt`` names an unknown preprocess method.
    """
    # Build the default per call; a list default object is shared by all calls.
    if image_paths is None:
        image_paths = ['img.jpg']

    # Read the cheap configuration first so a bad config fails before the
    # model is loaded.
    genres = _load_genres(os.path.join('training_data', 'genres.txt'))
    preprocess = _select_preprocess(
        os.path.join('training_data', 'preprocess.txt'))

    model = load_model(model_path)

    # Take dims 1 and 2 of the first layer's input shape as the target size.
    # NOTE(review): assumes a (batch, height, width, channels) layout — confirm.
    input_shape = model.layers[0].input_shape
    dimension = (input_shape[1], input_shape[2])
    screenshots = [process_screen(image_path, dimension, preprocess)
                   for image_path in image_paths]

    cutoffs = np.load(os.path.join('cutoffs', cutoff_file))

    predictions = model.predict(np.array(screenshots))
    for prediction in predictions:
        print(prediction)
        # A class is predicted when its score reaches that class's own cutoff.
        classes = [i for i, score in enumerate(prediction)
                   if score >= cutoffs[i]]
        print('Predicted genres:')
        for c in classes:
            print(genres[c])
        # NOTE(review): nothing is printed after this header; the ground-truth
        # listing appears to be missing — confirm the intended behavior.
        print('True genres:')


def _load_genres(path):
    """Return the lines of *path* without line terminators.

    Line ``i`` is used as the name of model output class ``i``, so blank
    lines are kept to preserve that index alignment.
    """
    with open(path, 'r', encoding='utf-8') as handler:
        # splitlines() removes the terminator only when one is present, unlike
        # slicing off the last character of every line.
        return handler.read().splitlines()


def _select_preprocess(path):
    """Return the preprocessing callable named in the config file at *path*.

    The file must hold a Python literal dict (parsed with
    ``ast.literal_eval``, which evaluates literals only) whose
    ``'preprocess'`` entry is ``'xception'``, ``'vgg'`` or ``'none'``.

    Raises:
        KeyError: If the dict has no ``'preprocess'`` key.
        ValueError: If the method name is not one of the supported values.
    """
    with open(path, 'r', encoding='utf-8') as preprocess_file:
        settings = ast.literal_eval(preprocess_file.read())
    method = settings['preprocess']
    if method == 'xception':
        return preprocess_xception
    if method == 'vgg':
        return imagenet_utils.preprocess_input
    if method == 'none':
        return lambda x: x
    # Fail here with a clear message instead of an unrelated error later on.
    raise ValueError(
        'Unknown preprocess method {!r} in {}'.format(method, path))
# preprocess a single screen
评论列表
文章目录