ppi_eval.py source code

python

Project: GraphSAGE  Author: williamleif
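```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import SGDClassifier
from sklearn.metrics import f1_score
from sklearn.multioutput import MultiOutputClassifier


def run_regression(train_embeds, train_labels, test_embeds, test_labels):
    # Seed NumPy's global RNG; estimators fit in parallel workers (n_jobs) may not inherit it
    np.random.seed(1)
    # Random baseline: one DummyClassifier per label. "stratified" samples predictions from
    # the training label distribution (the default before scikit-learn 0.24, now "prior")
    dummy = MultiOutputClassifier(DummyClassifier(strategy="stratified"))
    dummy.fit(train_embeds, train_labels)
    # One logistic-regression model per label, trained with SGD
    # (loss="log" was renamed "log_loss" in scikit-learn 1.1 and removed in 1.3)
    log = MultiOutputClassifier(SGDClassifier(loss="log_loss"), n_jobs=10)
    log.fit(train_embeds, train_labels)

    # Predict once rather than once per label column
    log_preds = log.predict(test_embeds)
    dummy_preds = dummy.predict(test_embeds)
    # On a single binary column, average="micro" F1 equals that column's accuracy
    for i in range(test_labels.shape[1]):
        print("F1 score", f1_score(test_labels[:, i], log_preds[:, i], average="micro"))
    for i in range(test_labels.shape[1]):
        print("Random baseline F1 score", f1_score(test_labels[:, i], dummy_preds[:, i], average="micro"))
```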
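The function prints one score per label column. Because each column is binary, scikit-learn's `f1_score(..., average="micro")` on a single column reduces to that column's accuracy, whereas the micro-averaged F1 typically reported for multi-label benchmarks such as PPI pools true positives, false positives, and false negatives over the whole label matrix. The sketch below is a NumPy-only illustration, not part of the GraphSAGE script; the helper names `micro_f1` and `per_column_accuracy` and the toy arrays are hypothetical. With scikit-learn, passing the full 2-D arrays, e.g. `f1_score(test_labels, log.predict(test_embeds), average="micro")`, should give the pooled score.

```python
import numpy as np


def micro_f1(y_true, y_pred):
    """Micro-averaged F1 over a binary multilabel indicator matrix.

    TP, FP and FN are pooled across every (sample, label) cell before
    precision and recall are combined: F1 = 2*TP / (2*TP + FP + FN).
    """
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    tp = np.sum(y_true & y_pred)
    fp = np.sum(~y_true & y_pred)
    fn = np.sum(y_true & ~y_pred)
    denom = 2 * tp + fp + fn
    # Assumption: return 0.0 when there are no positives at all
    # (scikit-learn's result here depends on its zero_division setting)
    return 2.0 * tp / denom if denom > 0 else 0.0


def per_column_accuracy(y_true, y_pred):
    """Per-label accuracy, i.e. what the per-column "micro" F1 loop reports."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    return (y_true == y_pred).mean(axis=0)


if __name__ == "__main__":
    # Hypothetical toy data: 4 samples, 3 labels
    y_true = np.array([[1, 0, 1], [0, 0, 1], [1, 1, 0], [0, 0, 0]])
    y_pred = np.array([[1, 0, 0], [0, 1, 1], [1, 1, 0], [0, 0, 0]])
    print("micro-F1 over all labels:", micro_f1(y_true, y_pred))
    print("per-column scores:", per_column_accuracy(y_true, y_pred))
```

On the toy data, the pooled micro-F1 is 0.8 while the per-column scores are 1.0, 0.75 and 0.75, so the two views can differ noticeably.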