classifier.py source code


Project: KATE   Author: hugochan
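# Imports needed by this function (Keras 2.x and scikit-learn).
from keras.callbacks import ReduceLROnPlateau, EarlyStopping
from sklearn.metrics import f1_score

# sigmoid_network(input_dim, n_labels) is defined elsewhere in classifier.py (not shown);
# presumably it builds and compiles a Keras model with one sigmoid output per label.


def multilabel_classifier(X_train, Y_train, X_val, Y_val, X_test, Y_test,
                          nb_epoch=200, batch_size=10, seed=7):
    # Note: `seed` is accepted but never used in this function.
    clf = sigmoid_network(X_train.shape[1], Y_train.shape[1])
    clf.fit(X_train, Y_train,
            epochs=nb_epoch,  # Keras 1.x called this argument `nb_epoch`
            batch_size=batch_size,
            shuffle=True,
            validation_data=(X_val, Y_val),
            callbacks=[
                # Multiply the learning rate by 0.2 after 3 epochs without val_loss
                # improvement, never going below min_lr.
                ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.01),
                # Stop training once val_loss has not improved by more than 1e-5 for 5 epochs.
                EarlyStopping(monitor='val_loss', min_delta=1e-5, patience=5, verbose=0, mode='auto'),
            ])
    # Threshold each label's sigmoid probability at 0.5 to get 0/1 predictions.
    pred = (clf.predict(X_test) > .5).astype(int)
    macro_f1 = f1_score(Y_test, pred, average='macro')  # unweighted mean of per-label F1
    micro_f1 = f1_score(Y_test, pred, average='micro')  # one F1 over pooled label decisions

    return [macro_f1, micro_f1]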
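The last lines of the function threshold the sigmoid outputs at 0.5 and report macro- and micro-averaged F1. The plain-Python sketch below spells out what those two averages compute. It uses neither scikit-learn nor anything from KATE; the helper names (binarize, f1_from_counts, label_counts, macro_micro_f1) and the toy data are hypothetical, chosen only for illustration. On inputs like these it should agree with f1_score(..., average='macro') and average='micro', but treat it as a sketch, not a replacement.

```python
def binarize(probs, threshold=0.5):
    """Turn per-label probabilities into 0/1 predictions, like (pred > .5)."""
    return [[1 if p > threshold else 0 for p in row] for row in probs]


def f1_from_counts(tp, fp, fn):
    """F1 = 2TP / (2TP + FP + FN); this sketch scores it as 0 when the denominator is 0."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def label_counts(y_true, y_pred, j):
    """True positives, false positives and false negatives for label column j."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 1 and p[j] == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 0 and p[j] == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t[j] == 1 and p[j] == 0)
    return tp, fp, fn


def macro_micro_f1(y_true, y_pred):
    n_labels = len(y_true[0])
    counts = [label_counts(y_true, y_pred, j) for j in range(n_labels)]
    # Macro: compute F1 separately for each label, then take the unweighted mean.
    macro = sum(f1_from_counts(*c) for c in counts) / n_labels
    # Micro: pool TP/FP/FN over all labels first, then compute a single F1.
    tp, fp, fn = map(sum, zip(*counts))
    micro = f1_from_counts(tp, fp, fn)
    return macro, micro


# Toy data: 4 samples, 3 labels; the third label is rare.
y_true = [[1, 0, 0], [1, 1, 0], [1, 0, 1], [0, 1, 0]]
probs = [[0.9, 0.1, 0.2], [0.8, 0.7, 0.1], [0.6, 0.3, 0.4], [0.2, 0.9, 0.3]]
y_pred = binarize(probs)
macro, micro = macro_micro_f1(y_true, y_pred)
print(f"macro F1 = {macro:.3f}, micro F1 = {micro:.3f}")  # → macro F1 = 0.667, micro F1 = 0.909
```

Here the first two labels are predicted perfectly, but the single positive of the rare third label is missed, so that label's F1 is 0. Macro averaging weights every label equally and drops to 2/3, while micro averaging pools all decisions (5 TP, 0 FP, 1 FN) and stays at 10/11. Reporting both, as the function does, shows whether performance holds up on infrequent labels.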