classifier.py source code

python

Project: KATE  Author: hugochan
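import numpy as np
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from sklearn.metrics import confusion_matrix, classification_report

# softmax_network(input_dim, n_classes) is a model-building helper defined
# elsewhere in the KATE project; it is not shown on this page.
def multiclass_classifier(X_train, Y_train, X_val, Y_val, X_test, Y_test, nb_epoch=200, batch_size=10, seed=7):
    # Note: `seed` is accepted but not used inside this function.
    clf = softmax_network(X_train.shape[1], Y_train.shape[1])
    clf.fit(X_train, Y_train,
            epochs=nb_epoch,
            batch_size=batch_size,
            shuffle=True,
            validation_data=(X_val, Y_val),
            callbacks=[
                # Multiply the learning rate by 0.2 after 3 epochs without val_loss
                # improvement; min_lr=0.01 is a floor the rate is never reduced below.
                ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.01),
                # Stop training after 5 epochs without a val_loss improvement of at least 1e-5.
                EarlyStopping(monitor='val_loss', min_delta=1e-5, patience=5, verbose=0, mode='auto'),
            ])
    # test_on_batch returns [loss, accuracy] when the model is compiled with an accuracy metric.
    acc = clf.test_on_batch(X_test, Y_test)[1]
    # Confusion matrix and per-class precision/recall on the test set.
    true = np.argmax(Y_test, axis=1)
    pred = np.argmax(clf.predict(X_test), axis=1)
    print(confusion_matrix(true, pred))
    print(classification_report(true, pred))
    return acc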
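The evaluation step above converts one-hot labels and softmax outputs back to class indices with `np.argmax`, then tallies them with scikit-learn. As a hedged, dependency-free illustration of what `confusion_matrix` and the precision/recall columns of `classification_report` compute (rows are the true class and columns the predicted class, following scikit-learn's convention), here is a minimal sketch in plain Python. The helper names and the toy data are hypothetical and are not part of KATE or scikit-learn.

```python
from typing import List, Sequence, Tuple


def argmax(row: Sequence[float]) -> int:
    """Index of the largest value; the first index wins on ties, like np.argmax."""
    return max(range(len(row)), key=lambda i: row[i])


def confusion_from_onehot(y_true_onehot: Sequence[Sequence[float]],
                          y_prob: Sequence[Sequence[float]]) -> List[List[int]]:
    """cm[i][j] counts samples whose true class is i and predicted class is j."""
    n_classes = len(y_true_onehot[0])
    cm = [[0] * n_classes for _ in range(n_classes)]
    for t_row, p_row in zip(y_true_onehot, y_prob):
        cm[argmax(t_row)][argmax(p_row)] += 1
    return cm


def accuracy(cm: List[List[int]]) -> float:
    """Fraction of samples on the diagonal (true class == predicted class)."""
    correct = sum(cm[i][i] for i in range(len(cm)))
    total = sum(sum(row) for row in cm)
    return correct / total


def per_class_precision_recall(cm: List[List[int]]) -> List[Tuple[float, float]]:
    """(precision, recall) per class; 0.0 when a denominator is zero."""
    n = len(cm)
    out = []
    for k in range(n):
        tp = cm[k][k]
        predicted_k = sum(cm[i][k] for i in range(n))  # column sum
        actual_k = sum(cm[k])                           # row sum
        precision = tp / predicted_k if predicted_k else 0.0
        recall = tp / actual_k if actual_k else 0.0
        out.append((precision, recall))
    return out


# Toy example (hypothetical data): 4 samples, 3 classes.
y_true_example = [[1, 0, 0], [0, 1, 0], [0, 0, 1], [0, 1, 0]]
y_prob_example = [[0.7, 0.2, 0.1], [0.1, 0.3, 0.6], [0.2, 0.2, 0.6], [0.1, 0.8, 0.1]]

cm = confusion_from_onehot(y_true_example, y_prob_example)
print(cm)                              # [[1, 0, 0], [0, 1, 1], [0, 0, 1]]
print(accuracy(cm))                    # 0.75
print(per_class_precision_recall(cm))  # [(1.0, 1.0), (1.0, 0.5), (0.5, 1.0)]
```

In the toy data, the second sample (true class 1) is predicted as class 2, which lowers recall for class 1 and precision for class 2 while leaving class 0 untouched.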