def validate(data, labels):
    """Run stratified ten-fold cross-validation and print mean metrics.

    For each fold a fresh model is built with ``create_model(data)``,
    trained on the fold's training portion, and evaluated on its held-out
    portion. Predictions are thresholded at 0.5 to obtain binary labels.
    The mean accuracy, precision, recall and F1 over the ten folds are
    printed.

    Args:
        data: Indexable collection whose first three elements are per-sample
            input arrays fed to the model as a list of three inputs. Each
            must support fancy indexing with an index array (presumably
            NumPy arrays -- confirm against the caller).
        labels: Per-sample target array supporting fancy indexing
            (presumably a 1-D NumPy array of 0/1 labels -- confirm).

    Returns:
        None. Results are only printed.
    """
    # StratifiedKFold partitions the samples so that every sample is tested
    # exactly once. The previous StratifiedShuffleSplit drew ten independent
    # random splits whose test sets could overlap, which is not k-fold CV.
    from sklearn.model_selection import StratifiedKFold

    # (printed name, scorer) pairs; the order matches the original output.
    metrics = (
        ('Accuracy', accuracy_score),
        ('Precision', precision_score),
        ('Recall', recall_score),
        ('F1-measure', f1_score),
    )
    scores = {name: [] for name, _ in metrics}

    # The model consumes exactly the first three input arrays.
    inputs = data[:3]

    folds = StratifiedKFold(n_splits=10, shuffle=True)
    for train_index, test_index in folds.split(data[0], labels):
        x_train = [x[train_index] for x in inputs]
        x_test = [x[test_index] for x in inputs]
        y_train, y_test = labels[train_index], labels[test_index]

        # Build a fresh, untrained model per fold so folds don't leak.
        model = create_model(data)
        # NOTE(review): `class_weight` is not defined in this function; it is
        # read from an enclosing/module scope -- confirm it is set before
        # validate() is called.
        model.fit(x_train, y_train,
                  epochs=100, batch_size=128, class_weight=class_weight)

        probs = model.predict(x_test, batch_size=128)
        # Reads the first output column of each prediction row and
        # thresholds it at 0.5 to get a 0/1 label.
        y_pred = [1 if p[0] > 0.5 else 0 for p in probs]

        for name, scorer in metrics:
            scores[name].append(scorer(y_test, y_pred))

    print('')
    for name, _ in metrics:
        print(name, np.mean(scores[name]))
评论列表
文章目录