citation_eval.py source code

Project: GraphSAGE  Author: williamleif
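import numpy as np


def run_regression(train_embeds, train_labels, test_embeds, test_labels):
    # Seed the global NumPy generator so SGD shuffling and the random baseline are repeatable.
    np.random.seed(1)
    from sklearn.linear_model import SGDClassifier
    from sklearn.dummy import DummyClassifier
    from sklearn.metrics import f1_score
    # "stratified" predicts random labels drawn from the training class distribution;
    # scikit-learn 0.24+ otherwise defaults to "prior" (always the majority class).
    dummy = DummyClassifier(strategy="stratified")
    dummy.fit(train_embeds, train_labels)
    # Logistic regression trained by SGD; loss="log" was renamed "log_loss" in scikit-learn 1.1.
    log = SGDClassifier(loss="log_loss", n_jobs=10)
    log.fit(train_embeds, train_labels)
    print("F1 score:", f1_score(test_labels, log.predict(test_embeds), average="micro"))
    print("Random baseline f1 score:", f1_score(test_labels, dummy.predict(test_embeds), average="micro"))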