bnn.py source code

python

Project: dac-training, author: jlonij
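import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold


def validate(data, labels, class_weight=None):
    '''
    Ten-fold cross-validation with stratified sampling.

    data: list of three NumPy arrays (the model's three inputs), indexed by
    sample along the first axis. labels: NumPy array of 0/1 labels.
    class_weight: optional dict passed through to model.fit(). create_model()
    is assumed to be defined elsewhere in bnn.py and to return a compiled
    Keras model with a single probability output per sample.
    '''
    accuracy_scores = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    # StratifiedKFold yields ten disjoint test folds that preserve the class
    # ratio. StratifiedShuffleSplit(n_splits=10) would instead draw ten
    # independent random splits whose test sets may overlap, which is not
    # ten-fold cross-validation.
    skf = StratifiedKFold(n_splits=10, shuffle=True)

    for train_index, test_index in skf.split(data[0], labels):
        # Slice all three inputs with the same fold indices.
        x_train_0, x_test_0 = data[0][train_index], data[0][test_index]
        x_train_1, x_test_1 = data[1][train_index], data[1][test_index]
        x_train_2, x_test_2 = data[2][train_index], data[2][test_index]

        y_train, y_test = labels[train_index], labels[test_index]

        # Build a fresh, untrained model for each fold so no weights carry over.
        model = create_model(data)
        model.fit([x_train_0, x_train_1, x_train_2], y_train,
            epochs=100, batch_size=128, class_weight=class_weight)

        # predict_classes() has been removed from recent Keras versions, so
        # threshold the predicted probabilities at 0.5 instead.
        y_pred = model.predict([x_test_0, x_test_1, x_test_2], batch_size=128)
        y_pred = [1 if y[0] > 0.5 else 0 for y in y_pred]

        accuracy_scores.append(accuracy_score(y_test, y_pred))
        precision_scores.append(precision_score(y_test, y_pred))
        recall_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred))

    # Report the mean of each metric over the ten folds.
    print('')
    print('Accuracy', np.mean(accuracy_scores))
    print('Precision', np.mean(precision_scores))
    print('Recall', np.mean(recall_scores))
    print('F1-measure', np.mean(f1_scores))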
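validate() combines three steps: stratified splitting into folds, turning predicted probabilities into 0/1 labels with a 0.5 cutoff, and per-fold accuracy/precision/recall/F1 averaged over the folds. As a rough, dependency-free illustration of what those steps compute, here is a sketch in plain Python. The helper names stratified_kfold_indices, binary_metrics and threshold are hypothetical (they are not part of bnn.py or scikit-learn), and the round-robin dealing below is just one simple way to stratify; scikit-learn's StratifiedKFold may assign samples to folds differently.

```python
import random
from collections import defaultdict


def stratified_kfold_indices(labels, n_splits=10, seed=0):
    """Yield (train_index, test_index) lists for stratified k-fold CV.

    Hypothetical helper: each class's indices are shuffled and dealt
    round-robin into n_splits folds, so every fold keeps roughly the overall
    class ratio and every sample is in exactly one test fold.
    """
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for i, y in enumerate(labels):
        by_class[y].append(i)

    folds = [[] for _ in range(n_splits)]
    offset = 0
    for y in sorted(by_class):
        idx = by_class[y]
        rng.shuffle(idx)
        for j, i in enumerate(idx):
            folds[(offset + j) % n_splits].append(i)
        # Continue dealing where the previous class stopped so that small
        # classes do not all pile into the first folds.
        offset += len(idx)

    for k in range(n_splits):
        test = sorted(folds[k])
        test_set = set(test)
        train = [i for i in range(len(labels)) if i not in test_set]
        yield train, test


def threshold(probabilities, cutoff=0.5):
    """Map probabilities to hard 0/1 labels (strictly greater than cutoff)."""
    return [1 if p > cutoff else 0 for p in probabilities]


def binary_metrics(y_true, y_pred):
    """Return (accuracy, precision, recall, F1) for 0/1 labels.

    Undefined ratios (zero denominator) are reported as 0.0.
    """
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == 1 and p == 1)
    fp = sum(1 for t, p in pairs if t == 0 and p == 1)
    fn = sum(1 for t, p in pairs if t == 1 and p == 0)
    correct = sum(1 for t, p in pairs if t == p)

    accuracy = correct / len(pairs)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return accuracy, precision, recall, f1
```

For example, with 80 negatives and 20 positives and n_splits=10, every test fold holds 10 samples, 2 of them positive. Averaging binary_metrics over the folds mirrors the np.mean calls at the end of validate(). Note that this is a mean of per-fold scores, not a score computed on all predictions pooled together.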