def multiclass_classifier(X_train, Y_train, X_val, Y_val, X_test, Y_test, nb_epoch=200, batch_size=10, seed=7):
    """Train a softmax classifier, print test-set diagnostics, return a score.

    Builds a network with ``softmax_network``. Its input width comes from
    ``X_train.shape[1]`` and its output width from ``Y_train.shape[1]``.
    Training uses ``fit`` with shuffling, plus early stopping and
    learning-rate reduction, both driven by validation loss. The confusion
    matrix and classification report for the test set are then printed.

    Args:
        X_train, Y_train: Training features and targets. The ``Y_*`` arrays
            are presumably one-hot encoded, because class labels are
            recovered with ``argmax`` over axis 1. TODO confirm with callers.
        X_val, Y_val: Validation data. Its loss (``val_loss``) drives both
            callbacks.
        X_test, Y_test: Held-out data for the final evaluation.
        nb_epoch: Maximum number of training epochs. Early stopping may end
            training sooner.
        batch_size: Mini-batch size used during ``fit``.
        seed: Currently unused by this function. NOTE(review): callers may
            expect it to seed the RNGs, so confirm before relying on
            reproducibility.

    Returns:
        Element ``[1]`` of ``test_on_batch`` on the full test set. Element 0
        is the loss, so this is the first compiled metric. It is presumably
        accuracy, which assumes the model was compiled with an accuracy
        metric; verify in ``softmax_network``.
    """
    clf = softmax_network(X_train.shape[1], Y_train.shape[1])
    clf.fit(
        X_train, Y_train,
        epochs=nb_epoch,
        batch_size=batch_size,
        shuffle=True,
        validation_data=(X_val, Y_val),
        callbacks=[
            # NOTE(review): min_lr is a lower bound on the learning rate. If
            # the optimizer starts at or below 0.01, no reduction will ever
            # happen. Confirm that 0.01 is intended.
            ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.01),
            EarlyStopping(monitor='val_loss', min_delta=1e-5, patience=5, verbose=0, mode='auto'),
        ],
    )
    # The whole test set is scored as a single batch; index 1 skips the loss.
    acc = clf.test_on_batch(X_test, Y_test)[1]

    # Confusion matrix and precision/recall, computed on integer class indices.
    true = np.argmax(Y_test, axis=1)
    pred = np.argmax(clf.predict(X_test), axis=1)
    # A parenthesized single-argument print runs on both Python 2 and 3.
    print(confusion_matrix(true, pred))
    print(classification_report(true, pred))
    return acc
评论列表
文章目录