nn.py source code

python

Project: dac-training  Author: jlonij
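import numpy as np
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score)
from sklearn.model_selection import StratifiedShuffleSplit


def validate(data, labels):
    '''
    Stratified shuffle-split validation: ten random train/test splits,
    each holding out 10% of the samples while preserving class proportions.
    Prints the mean accuracy, precision, recall and F1 over the ten splits.
    '''
    accuracy_scores = []
    precision_scores = []
    recall_scores = []
    f1_scores = []

    # n_splits=10 draws ten independent stratified splits; with test_size
    # left unset, scikit-learn holds out 10% of the data in each one.
    sss = StratifiedShuffleSplit(n_splits=10)
    for train_index, test_index in sss.split(data, labels):
        x_train, x_test = data[train_index], data[test_index]
        y_train, y_test = labels[train_index], labels[test_index]

        # load_model and class_weight are defined elsewhere in nn.py (not
        # shown); load_model is assumed to return a freshly compiled Keras
        # model, so no trained weights carry over between splits.
        model = load_model(data)
        model.fit(x_train, y_train, epochs=100, batch_size=128,
                  class_weight=class_weight)

        # Sequential.predict_classes was removed in TensorFlow 2.6; this
        # reproduces its behaviour: argmax for multi-unit outputs, a 0.5
        # threshold for a single sigmoid unit.
        probs = model.predict(x_test, batch_size=128)
        if probs.shape[-1] > 1:
            y_pred = probs.argmax(axis=-1)
        else:
            y_pred = (probs > 0.5).astype('int32').ravel()

        accuracy_scores.append(accuracy_score(y_test, y_pred))
        precision_scores.append(precision_score(y_test, y_pred))
        recall_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred))

    print('')
    print('Accuracy', np.mean(accuracy_scores))
    print('Precision', np.mean(precision_scores))
    print('Recall', np.mean(recall_scores))
    print('F1-measure', np.mean(f1_scores))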
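`StratifiedShuffleSplit(n_splits=10)` draws ten independent random splits, so test sets can overlap and some samples may never be tested; strictly speaking this is not ten-fold cross-validation. When every sample should be tested exactly once, scikit-learn's `StratifiedKFold` is the usual choice. The following is a minimal sketch under stated assumptions, not the project's code: a `LogisticRegression` with `class_weight='balanced'` stands in for the Keras model returned by `load_model` and for the project's `class_weight`, `make_classification` generates synthetic data, and the helper names `coverage_counts` and `kfold_validate` are hypothetical. It counts how many test sets each sample lands in under both splitters, then prints the mean k-fold metrics.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from sklearn.model_selection import StratifiedKFold, StratifiedShuffleSplit


def coverage_counts(splitter, data, labels):
    """Count how many test sets each sample lands in for a given splitter."""
    counts = np.zeros(len(labels), dtype=int)
    for _, test_index in splitter.split(data, labels):
        counts[test_index] += 1
    return counts


def kfold_validate(data, labels, n_splits=10, seed=0):
    """Stratified k-fold CV returning mean accuracy, precision, recall, F1.

    Hypothetical stand-in: LogisticRegression replaces the project's Keras
    model, and class_weight='balanced' replaces its class_weight setting.
    """
    scores = {'accuracy': [], 'precision': [], 'recall': [], 'f1': []}
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    for train_index, test_index in skf.split(data, labels):
        # A fresh model per fold, so no fitted state carries over.
        model = LogisticRegression(max_iter=1000, class_weight='balanced')
        model.fit(data[train_index], labels[train_index])
        y_test, y_pred = labels[test_index], model.predict(data[test_index])

        scores['accuracy'].append(accuracy_score(y_test, y_pred))
        scores['precision'].append(precision_score(y_test, y_pred))
        scores['recall'].append(recall_score(y_test, y_pred))
        scores['f1'].append(f1_score(y_test, y_pred))
    return {name: float(np.mean(values)) for name, values in scores.items()}


if __name__ == '__main__':
    # Synthetic, imbalanced binary data in place of the project's data.
    X, y = make_classification(n_samples=500, n_features=20,
                               weights=[0.8, 0.2], random_state=0)

    kfold = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
    shuffle_split = StratifiedShuffleSplit(n_splits=10, random_state=0)
    # Histogram of test-set appearances: k-fold puts every sample in bin 1.
    print('k-fold coverage:', np.bincount(coverage_counts(kfold, X, y)))
    # Shuffle-split leaves some samples untested and tests others repeatedly.
    print('shuffle-split coverage:',
          np.bincount(coverage_counts(shuffle_split, X, y)))

    for name, value in kfold_validate(X, y).items():
        print(name, round(value, 3))
```

Swapping the splitter changes only which indices each fold receives; the per-fold training and the averaging of metrics stay the same as in `validate`.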