predict_ensemble_voting.py 文件源码

python
阅读 19 收藏 0 点赞 0 评论 0

项目:tefla 作者: litan 项目源码 文件源码
def _build_one_crop_predictor(model_path, cnf_path, weights_path, convert, sync):
    """Build a ``OneCropPredictor`` for one ensemble member.

    Args:
        model_path: Path of the model definition module; the loaded module
            must expose ``model``, ``image_size`` and ``crop_size``.
        cnf_path: Path of the training config module; the loaded module must
            expose a ``cnf`` mapping.
        weights_path: Checkpoint path handed to the predictor.
        convert: When truthy, a preprocessor is created via
            ``convert_preprocessor(model_def.image_size[0])``; otherwise no
            preprocessor is used.
        sync: Forwarded unchanged to ``create_prediction_iter``.

    Returns:
        The constructed ``OneCropPredictor``.
    """
    model_def = util.load_module(model_path)
    cnf = util.load_module(cnf_path).cnf
    # Fall back to a no-op standardizer when the config does not define one.
    standardizer = cnf.get('standardizer', NoOpStandardizer())
    preprocessor = convert_preprocessor(model_def.image_size[0]) if convert else None
    prediction_iterator = create_prediction_iter(
        cnf, standardizer, model_def.crop_size, preprocessor, sync)
    # A multi-crop variant was previously sketched here as
    # QuasiCropPredictor(model, cnf, weights, prediction_iterator, 20).
    return OneCropPredictor(model_def.model, cnf, weights_path, prediction_iterator)


def predict(model, training_cnf, predict_dir, weights_from, dataset_name, convert, image_size, sync,
            predict_type):
    """Predict classes for the images in ``predict_dir`` with a voting ensemble.

    Each ensemble member votes for the arg-max class of its per-image output,
    the votes are combined with ``mode`` and the result is written as CSV to
    ``<predict_dir>/../results/<dataset_name>/predictions_class.csv``.

    NOTE: the ensemble members are currently hard coded, so ``model``,
    ``training_cnf``, ``weights_from``, ``image_size`` and ``predict_type``
    are accepted for interface compatibility but are not used.

    Args:
        model: Unused (see note above).
        training_cnf: Unused (see note above).
        predict_dir: Directory whose image files are predicted.
        weights_from: Unused (see note above).
        dataset_name: Name of the sub-directory of ``results`` to write into.
        convert: Whether each member should preprocess images with
            ``convert_preprocessor``.
        image_size: Unused (see note above).
        sync: Forwarded to ``create_prediction_iter`` for every member.
        predict_type: Unused (see note above).

    Raises:
        OSError: If the results directory cannot be created.
    """
    images = data.get_image_files(predict_dir)

    # For now the models, cnfs and weights are hard coded.
    # TODO: take these from program inputs or an ensembling config file.
    member_specs = [
        # (model module, training cnf module, weights checkpoint)
        ('examples/mnist_model_sa.py', 'examples/mnist_cnf.py', 'weights.sa/model-epoch-97.ckpt'),
        ('examples/mnist_model.py', 'examples/mnist_cnf.py', 'weights.rv/model-epoch-31.ckpt'),
    ]

    predictors = []
    for index, (model_path, cnf_path, weights_path) in enumerate(member_specs, 1):
        print('Creating predictor %d' % index)
        predictors.append(
            _build_one_crop_predictor(model_path, cnf_path, weights_path, convert, sync))

    predictor = EnsemblePredictor(predictors)

    def softmax_result_to_vote(predictions):
        # One vote per row: the index of the largest value along axis 1.
        return predictions.argmax(axis=1)

    def vote_combiner(votes):
        # Most frequent vote along axis 0, flattened to 1-D.
        # assumes votes is stacked as (n_members, n_images) -- TODO confirm
        return mode(votes, axis=0)[0].reshape(-1)

    class_predictions = predictor.predict_with_voting(
        images,
        [softmax_result_to_vote] * len(predictors),
        vote_combiner
    )

    # Normalise the path once so the directory that gets created is exactly
    # the one the CSV is written to (os.makedirs is also documented to be
    # unreliable when the path still contains '..' components).
    results_dir = os.path.abspath(os.path.join(predict_dir, '..', 'results', dataset_name))
    try:
        os.makedirs(results_dir)
    except OSError:
        # An already existing directory is fine; anything else is a real error.
        if not os.path.isdir(results_dir):
            raise

    names = data.get_names(images)
    image_class_predictions = np.column_stack([names, class_predictions])
    title = np.array(['image', 'label'])
    image_class_predictions = np.vstack([title, image_class_predictions])
    prediction_class_file = os.path.join(results_dir, 'predictions_class.csv')
    np.savetxt(prediction_class_file, image_class_predictions, delimiter=",", fmt="%s")
    print('Class predictions saved to: %s' % prediction_class_file)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号