ensemble.py 文件源码

python
阅读 24 收藏 0 点赞 0 评论 0

项目:qtim_ROP 作者: QTIM-Lab 项目源码 文件源码
def evaluate_ensemble(models_dir, test_images, out_dir, rf=False):
    """Evaluate every model under *models_dir* and their averaged ensemble.

    For each sub-directory of ``models_dir`` a model is loaded, class
    probabilities are predicted for the test images, and a confusion matrix
    image is written to ``out_dir`` with the quadratic-weighted Cohen's kappa
    embedded in the file name. The per-model probabilities are then averaged
    and the same evaluation is written for the ensemble.

    Predictions are cached in ``out_dir`` as ``ypred_<i>.npy`` and re-used on
    later runs when the file already exists.

    Args:
        models_dir: Directory whose sub-directories each hold one model.
        test_images: Test image location, passed to
            ``imgs_by_class_to_th_array`` together with ``CLASS_LABELS``.
        out_dir: Directory receiving cached predictions and confusion plots.
            NOTE(review): it is not created here - presumably it must already
            exist; confirm against callers.
        rf: If true, load a ``RetinaRF`` (CNN + random forest) using the
            config/pickle found by ``locate_config``; otherwise load a
            ``RetiNet`` from the first ``*.yaml`` file in the model directory.

    Raises:
        ValueError: If ``models_dir`` yields no model sub-directories.
    """

    # Get images and true classes
    img_arr, y_true = imgs_by_class_to_th_array(test_images, CLASS_LABELS)
    print(img_arr.shape)

    y_pred_all = []

    # Load each model
    for i, model_dir in enumerate(get_subdirs(models_dir)):

        # NOTE(review): the model is loaded even when cached predictions
        # exist below; kept as-is so a broken model directory still fails
        # loudly.
        if rf:
            print("Loading CNN+RF #{}".format(i))
            model_config, rf_pkl = locate_config(model_dir)
            model = RetinaRF(model_config, rf_pkl=rf_pkl)
        else:
            print("Loading CNN #{}".format(i))
            config_file = glob(join(model_dir, '*.yaml'))[0]
            model = RetiNet(config_file).model

        # Predicted probabilities, cached on disk by model index.
        # NOTE(review): the cache key is only the index, so stale files are
        # re-used if the set or order of models changes - clear out_dir then.
        print("Making predictions...")
        ypred_out = join(out_dir, 'ypred_{}.npy'.format(i))

        if not exists(ypred_out):
            y_preda = model.predict(img_arr)
            np.save(ypred_out, y_preda)
        else:
            y_preda = np.load(ypred_out)

        y_pred_all.append(y_preda)
        # Most probable class per sample (assumes rows are samples and
        # columns are classes - TODO confirm for the RetinaRF path).
        y_pred = np.argmax(y_preda, axis=1)

        kappa = cohen_kappa_score(y_true, y_pred, weights='quadratic')
        confusion(y_true, y_pred, CLASS_LABELS,
                  join(out_dir, 'confusion_split{}_k={:.3f}.png'.format(i, kappa)))

    if not y_pred_all:
        # np.dstack would raise a far less helpful ValueError on an empty list.
        raise ValueError("No model sub-directories found in '{}'".format(models_dir))

    # Evaluate ensemble: stack the per-model probabilities along a third axis
    # and average over it.
    y_preda_ensemble = np.mean(np.dstack(y_pred_all), axis=2)
    y_pred_ensemble = np.argmax(y_preda_ensemble, axis=1)
    # Use the same quadratic weighting as the per-model scores so that the
    # "k=" values in the output file names are directly comparable.
    kappa = cohen_kappa_score(y_true, y_pred_ensemble, weights='quadratic')
    confusion(y_true, y_pred_ensemble, CLASS_LABELS,
              join(out_dir, 'confusion_ensemble_k={:.3f}.png'.format(kappa)))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号