classify_image.py 文件源码

python
阅读 21 收藏 0 点赞 0 评论 0

项目:games-cnn 作者: vanHavel 项目源码 文件源码
def classify_image(image_paths=('img.jpg',),
    model_path=os.path.join('model', 'model.mod'),
    cutoff_file='cutoffs.npy'):
    """Predict and print the genres of one or more images.

    Loads the model from ``model_path``, reads the genre names from
    ``training_data/genres.txt`` and the preprocessing method from
    ``training_data/preprocess.txt``, runs every image through
    ``process_screen`` and the model, and prints each raw prediction
    followed by the genres whose score reaches the per-class cutoff.

    Args:
        image_paths: Iterable of image file paths; each one is handed to
            ``process_screen``. The default is an immutable tuple so the
            default value can never be shared and mutated across calls.
        model_path: Path passed to ``load_model``.
        cutoff_file: Name of a file inside the ``cutoffs`` directory, loaded
            with ``np.load`` and indexed by class position.

    Raises:
        ValueError: If the ``preprocess`` entry of ``preprocess.txt`` is not
            one of ``'xception'``, ``'vgg'`` or ``'none'``.
    """
    # load model
    model = load_model(model_path)

    # read genre file; strip only the line terminator so that a final line
    # without a trailing newline keeps its last character
    genre_file_path = os.path.join('training_data', 'genres.txt')
    with open(genre_file_path, 'r') as handler:
        genres = [line.rstrip('\n') for line in handler]

    # determine preprocess method (the file holds a Python dict literal)
    preprocess_path = os.path.join('training_data', 'preprocess.txt')
    with open(preprocess_path, 'r') as preprocess_file:
        dictionary = ast.literal_eval(preprocess_file.read())
    preprocess_method = dictionary['preprocess']
    if preprocess_method == 'xception':
        preprocess = preprocess_xception
    elif preprocess_method == 'vgg':
        preprocess = imagenet_utils.preprocess_input
    elif preprocess_method == 'none':
        preprocess = lambda x: x
    else:
        # fail early with a clear message instead of hitting an unbound
        # local variable further down
        raise ValueError(
            'Unknown preprocess method {!r} in {}'.format(
                preprocess_method, preprocess_path))

    # preprocess images; the target size is taken from entries 1 and 2 of the
    # first layer's input shape (presumably height and width -- confirm
    # against process_screen)
    input_shape = model.layers[0].input_shape
    dimension = (input_shape[1], input_shape[2])
    screenshots = [process_screen(image_path, dimension, preprocess)
                   for image_path in image_paths]

    # load cutoffs
    cutoffs = np.load(os.path.join('cutoffs', cutoff_file))

    # predict classes
    predictions = model.predict(np.array(screenshots))
    for prediction in predictions:
        print(prediction)
        # a class is predicted when its score reaches that class's cutoff
        classes = [i for i, score in enumerate(prediction)
                   if score >= cutoffs[i]]
        print('Predicted genres:')
        for c in classes:
            print(genres[c])
        # NOTE(review): only the header is printed here; no true-genre
        # labels follow it in this function.
        print('True genres:')

# preprocess a single screen
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号